use super::*;
impl SurvivalLocationScaleFamily {
/// Rebuilds per-block coefficient / linear-predictor states from a flat joint
/// coefficient vector.
///
/// `theta` is cut into blocks using `joint_block_offsets()`; each block's
/// linear predictor is `solver_design() * beta + solver_offset()`. After every
/// block exists, each block (in order) is passed through
/// `post_update_block_beta`, which sees the already-processed earlier blocks.
/// The linear predictor is recomputed only when that call actually changed the
/// coefficients.
///
/// # Errors
/// Returns a dimension-mismatch error when `theta.len()` differs from the
/// total joint width, and propagates any error from `post_update_block_beta`.
pub(crate) fn parametric_aft_states_from_theta(
    &self,
    theta: &Array1<f64>,
    specs: &[ParameterBlockSpec],
) -> Result<Vec<ParameterBlockState>, String> {
    let offsets = self.joint_block_offsets();
    let expected_len = offsets.last().copied().unwrap_or(0);
    if theta.len() != expected_len {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "parametric-AFT direct MLE theta length mismatch: got {}, expected {}",
                theta.len(),
                expected_len
            ),
        }
        .into());
    }

    // Slice theta into blocks and evaluate each block's linear predictor.
    let mut states: Vec<ParameterBlockState> = specs
        .iter()
        .enumerate()
        .map(|(block, spec)| {
            let beta = theta
                .slice(s![offsets[block]..offsets[block + 1]])
                .to_owned();
            let eta = spec.solver_design().matrixvectormultiply(&beta) + spec.solver_offset();
            ParameterBlockState { beta, eta }
        })
        .collect();

    // Apply the per-block post-update hook; only refresh eta when beta moved.
    for (block, spec) in specs.iter().enumerate() {
        let candidate = states[block].beta.clone();
        let projected = self.post_update_block_beta(&states, block, spec, candidate)?;
        if projected == states[block].beta {
            continue;
        }
        let state = &mut states[block];
        state.beta.assign(&projected);
        state.eta =
            spec.solver_design().matrixvectormultiply(&state.beta) + spec.solver_offset();
    }
    Ok(states)
}
/// Fits the parametric-AFT model by direct maximum likelihood with a damped
/// Newton iteration on the flat joint coefficient vector.
///
/// Starting from each block's `initial_beta` (zeros for blocks without one),
/// each iteration:
/// 1. evaluates the log-likelihood and joint gradient, preferring the exact
///    kernel gradient when `exact_newton_joint_loglik_gradient` supplies one;
/// 2. stops when the largest absolute gradient entry is `<= grad_tol`;
/// 3. solves `H delta = g` with the exact joint Hessian. The code requires `H`
///    to admit a Cholesky factorization and uses `delta` as an ascent
///    direction. Levenberg damping `tau * I` is added until the factorization
///    succeeds;
/// 4. caps the step length with each block's `max_feasible_step_size`, then
///    halves it until the Armijo sufficient-increase condition holds.
///
/// The loop also stops, without error, when no step down to `1e-12` is
/// accepted or after `max_iter` iterations.
///
/// Returns the final block states, their log-likelihood, and the exact joint
/// Hessian evaluated at those states.
///
/// # Errors
/// Fails when the specs do not validate, there are no free coefficients, or
/// an `initial_beta` has the wrong width. It also fails when the
/// log-likelihood, gradient, Hessian, or Newton step is non-finite, when a
/// Hessian cannot be assembled, or when the Hessian is not factorizable even
/// at maximal damping.
pub(crate) fn fit_parametric_aft_direct_mle(
    &self,
    specs: &[ParameterBlockSpec],
    max_iter: usize,
    grad_tol: f64,
) -> Result<(Vec<ParameterBlockState>, f64, Array2<f64>), String> {
    use gam_linalg::faer_ndarray::FaerCholesky;
    self.validate_joint_specs(
        specs,
        "SurvivalLocationScaleFamily direct parametric-AFT MLE",
    )?;
    let offsets = self.joint_block_offsets();
    let p_total = *offsets.last().unwrap_or(&0);
    if p_total == 0 {
        return Err(SurvivalLocationScaleError::InvalidConfiguration {
            reason: "direct parametric-AFT MLE has no free coefficients".to_string(),
        }
        .into());
    }
    // Seed theta from the per-block initial coefficients.
    let mut theta = Array1::<f64>::zeros(p_total);
    for (b, spec) in specs.iter().enumerate() {
        if let Some(beta0) = spec.initial_beta.as_ref() {
            if beta0.len() != offsets[b + 1] - offsets[b] {
                return Err(SurvivalLocationScaleError::DimensionMismatch {
                    reason: format!(
                        "direct parametric-AFT MLE block {b} initial_beta length {} != block width {}",
                        beta0.len(),
                        offsets[b + 1] - offsets[b]
                    ),
                }
                .into());
            }
            theta
                .slice_mut(s![offsets[b]..offsets[b + 1]])
                .assign(beta0);
        }
    }
    // Build states (which applies `post_update_block_beta`) and keep theta in
    // lockstep with the resulting coefficients.
    let mut states = self.parametric_aft_states_from_theta(&theta, specs)?;
    for (b, state) in states.iter().enumerate() {
        theta
            .slice_mut(s![offsets[b]..offsets[b + 1]])
            .assign(&state.beta);
    }
    let mut ll = self.log_likelihood_only(&states)?;
    if !ll.is_finite() {
        return Err(SurvivalLocationScaleError::NumericalFailure {
            reason: format!(
                "direct parametric-AFT MLE: non-finite initial log-likelihood {ll}"
            ),
        }
        .into());
    }
    for _ in 0..max_iter {
        let (ll_now, block_gradients) =
            self.evaluate_log_likelihood_and_block_gradients(&states)?;
        ll = ll_now;
        // Flatten the block gradients into one joint gradient vector.
        let mut g = Array1::<f64>::zeros(p_total);
        if block_gradients.len() != specs.len() {
            return Err(SurvivalLocationScaleError::DimensionMismatch {
                reason: format!(
                    "direct parametric-AFT MLE gradient block count mismatch: gradients={}, specs={}",
                    block_gradients.len(),
                    specs.len()
                ),
            }
            .into());
        }
        for (b, gb) in block_gradients.iter().enumerate() {
            if gb.len() != offsets[b + 1] - offsets[b] {
                return Err(SurvivalLocationScaleError::DimensionMismatch {
                    reason: format!(
                        "direct parametric-AFT MLE block {b} gradient length {} != block width {}",
                        gb.len(),
                        offsets[b + 1] - offsets[b]
                    ),
                }
                .into());
            }
            g.slice_mut(s![offsets[b]..offsets[b + 1]]).assign(gb);
        }
        if !g.iter().all(|v| v.is_finite()) {
            return Err(SurvivalLocationScaleError::NumericalFailure {
                reason: "direct parametric-AFT MLE: non-finite gradient".to_string(),
            }
            .into());
        }
        // Prefer the exact joint-kernel gradient when the family provides one.
        if let Some(kernel_g) = self.exact_newton_joint_loglik_gradient(&states)? {
            if kernel_g.len() != p_total {
                return Err(SurvivalLocationScaleError::DimensionMismatch {
                    reason: format!(
                        "direct parametric-AFT MLE: kernel gradient length {} != p_total {}",
                        kernel_g.len(),
                        p_total
                    ),
                }
                .into());
            }
            if !kernel_g.iter().all(|v| v.is_finite()) {
                return Err(SurvivalLocationScaleError::NumericalFailure {
                    reason: "direct parametric-AFT MLE: non-finite kernel gradient".to_string(),
                }
                .into());
            }
            g = kernel_g;
        }
        // Convergence test on the infinity norm of the gradient.
        let grad_norm = g.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
        if grad_norm <= grad_tol {
            break;
        }
        let h = self.exact_newton_joint_hessian(&states)?.ok_or_else(|| {
            SurvivalLocationScaleError::NumericalFailure {
                reason: "direct parametric-AFT MLE: joint Hessian assembly failed".to_string(),
            }
        })?;
        if !h.iter().all(|v| v.is_finite()) {
            return Err(SurvivalLocationScaleError::NumericalFailure {
                reason: "direct parametric-AFT MLE: non-finite joint Hessian".to_string(),
            }
            .into());
        }
        // Damping is measured relative to the largest |diagonal| entry (>= 1).
        let h_scale = h
            .diag()
            .iter()
            .fold(0.0_f64, |acc, &v| acc.max(v.abs()))
            .max(1.0);
        // Levenberg damping: retry Cholesky with tau * I added, growing tau
        // geometrically until the factorization succeeds or tau is too large.
        let mut tau = 0.0_f64;
        let delta = loop {
            let mut damped = h.clone();
            if tau > 0.0 {
                for i in 0..p_total {
                    damped[[i, i]] += tau;
                }
            }
            match damped.cholesky(faer::Side::Lower) {
                Ok(chol) => break chol.solvevec(&g),
                Err(_) => {
                    tau = if tau == 0.0 {
                        LEVENBERG_INITIAL_DAMPING_REL * h_scale
                    } else {
                        tau * LEVENBERG_DAMPING_GROWTH
                    };
                    if tau > LEVENBERG_MAX_DAMPING_REL * h_scale {
                        return Err(SurvivalLocationScaleError::NumericalFailure {
                            reason:
                                "direct parametric-AFT MLE: Hessian not factorizable even with maximal damping"
                                    .to_string(),
                        }
                        .into());
                    }
                }
            }
        };
        if !delta.iter().all(|v| v.is_finite()) {
            return Err(SurvivalLocationScaleError::NumericalFailure {
                reason: "direct parametric-AFT MLE: non-finite Newton step".to_string(),
            }
            .into());
        }
        // Cap the initial step so every block stays feasible.
        let mut alpha = 1.0_f64;
        for (b, spec_offset) in offsets.iter().take(specs.len()).enumerate() {
            let block_delta = delta.slice(s![*spec_offset..offsets[b + 1]]).to_owned();
            if let Some(a_max) = self.max_feasible_step_size(&states, b, &block_delta)? {
                alpha = alpha.min(a_max);
            }
        }
        // Armijo backtracking for sufficient increase of the log-likelihood.
        let directional = g.dot(&delta);
        const ARMIJO_C: f64 = 1e-4;
        const BACKTRACK: f64 = 0.5;
        const MIN_ALPHA: f64 = 1e-12;
        let mut accepted: Option<(Array1<f64>, Vec<ParameterBlockState>, f64)> = None;
        while alpha >= MIN_ALPHA {
            let trial_theta = &theta + &(alpha * &delta);
            if let Ok(cand_states) = self.parametric_aft_states_from_theta(&trial_theta, specs)
                && let Ok(cand_ll) = self.log_likelihood_only(&cand_states)
                && cand_ll.is_finite()
                && cand_ll >= ll + ARMIJO_C * alpha * directional
            {
                accepted = Some((trial_theta, cand_states, cand_ll));
                break;
            }
            alpha *= BACKTRACK;
        }
        match accepted {
            Some((new_theta, new_states, new_ll)) => {
                // The candidate states carry coefficients processed by
                // `post_update_block_beta`. Re-sync theta from them, exactly as
                // for the seed, so the next Newton step is applied at the same
                // point where the gradient and Hessian are evaluated.
                theta = new_theta;
                for (b, state) in new_states.iter().enumerate() {
                    theta
                        .slice_mut(s![offsets[b]..offsets[b + 1]])
                        .assign(&state.beta);
                }
                states = new_states;
                ll = new_ll;
            }
            // No acceptable step length: return the current iterate.
            None => break,
        }
    }
    let h_final = self.exact_newton_joint_hessian(&states)?.ok_or_else(|| {
        SurvivalLocationScaleError::NumericalFailure {
            reason: "direct parametric-AFT MLE: final joint Hessian assembly failed"
                .to_string(),
        }
    })?;
    if !h_final.iter().all(|v| v.is_finite()) {
        return Err(SurvivalLocationScaleError::NumericalFailure {
            reason: "direct parametric-AFT MLE: non-finite final joint Hessian".to_string(),
        }
        .into());
    }
    Ok((states, ll, h_final))
}
/// Assembles the within-block (diagonal) Hessian blocks from precomputed
/// joint quantities, without any cross-block terms.
///
/// Returns `[time, threshold, log_sigma]`, followed by the link-wiggle block
/// when the dynamic geometry carries all three wiggle bases.
///
/// Each block is a sum of weighted cross-products `X^T diag(w) X`. When a
/// separate derivative design exists, the block also gets the exit/derivative
/// cross term in both orders, which keeps it symmetric. The weights negate the
/// `q.d2_*` / `q.d1_*` curvature terms, and the time block uses
/// `time_channel_nll_curvatures`.
///
/// When the rayon pool has more than one thread, the four blocks are built
/// concurrently with `rayon::join`. The inner faer products then run
/// sequentially (presumably to avoid nested oversubscription).
///
/// # Errors
/// Propagates failures from building the dynamic geometry and from the
/// weighted cross-product helpers.
///
/// # Panics
/// Unwraps `q.dq_t_entry`, `q.dq_ls_entry` and `q.d2q_ls_entry`; these are
/// assumed to be populated by the quantity collector.
pub(crate) fn assemble_block_diagonal_hessians_from_quantities(
    &self,
    q: &SurvivalJointQuantities,
    block_states: &[ParameterBlockState],
) -> Result<Vec<Array2<f64>>, String> {
    let dynamic = self.build_dynamic_geometry(block_states)?;
    // Densify designs once; entry designs fall back to the exit design.
    let x_threshold_exit_cow = self.x_threshold.to_dense_cow();
    let x_threshold_exit = &*x_threshold_exit_cow;
    let x_threshold_entry_cow = self
        .x_threshold_entry
        .as_ref()
        .map(DesignMatrix::to_dense_cow);
    let x_threshold_entry = x_threshold_entry_cow
        .as_ref()
        .map_or(x_threshold_exit, |c| &**c);
    let x_threshold_deriv_cow = self
        .x_threshold_deriv
        .as_ref()
        .map(DesignMatrix::to_dense_cow);
    let x_threshold_deriv = x_threshold_deriv_cow.as_deref();
    let x_log_sigma_exit_cow = self.x_log_sigma.to_dense_cow();
    let x_log_sigma_exit = &*x_log_sigma_exit_cow;
    let x_log_sigma_entry_cow = self
        .x_log_sigma_entry
        .as_ref()
        .map(DesignMatrix::to_dense_cow);
    let x_log_sigma_entry = x_log_sigma_entry_cow
        .as_ref()
        .map_or(x_log_sigma_exit, |c| &**c);
    let x_log_sigma_deriv_cow = self
        .x_log_sigma_deriv
        .as_ref()
        .map(DesignMatrix::to_dense_cow);
    let x_log_sigma_deriv = x_log_sigma_deriv_cow.as_deref();
    // Outer (block-level) parallelism and inner (faer) parallelism are exclusive.
    let use_outer_parallel = rayon::current_num_threads() > 1;
    let product_parallelism = if use_outer_parallel {
        faer::Par::Seq
    } else {
        faer::get_global_parallelism()
    };
    // Time block: entry, exit and derivative channels via the time Jacobians.
    let assemble_h_time = || -> Result<Array2<f64>, String> {
        let nll = q.time_channel_nll_curvatures();
        Ok(safe_fast_xt_diag_x_with_parallelism(
            &dynamic.time_jac_entry,
            &nll.h0,
            product_parallelism,
        ) + safe_fast_xt_diag_x_with_parallelism(
            &dynamic.time_jac_exit,
            &nll.h1,
            product_parallelism,
        ) + safe_fast_xt_diag_x_with_parallelism(
            &dynamic.time_jac_deriv,
            &nll.d,
            product_parallelism,
        ))
    };
    // Threshold block.
    let assemble_h_tt = || -> Result<Array2<f64>, String> {
        if let Some(x_t_deriv) = x_threshold_deriv {
            let h_exit = -(&q.d2_q1 * &q.dq_t.mapv(|v| safe_product(v, v))
                + &q.d2_qdot1 * &q.dqdot_t.mapv(|v| safe_product(v, v))
                + &q.d1_qdot1 * &q.d2qdot_tt);
            let h_entry =
                -(&q.d2_q0 * &q.dq_t_entry.as_ref().unwrap().mapv(|v| safe_product(v, v)));
            let h_deriv = -(&q.d2_qdot1 * &q.dqdot_td.mapv(|v| safe_product(v, v)));
            let h_exit_deriv =
                -(&q.d2_qdot1 * &(&q.dqdot_t * &q.dqdot_td) + &q.d1_qdot1 * &q.d2qdot_ttd);
            let mut h_tt = weighted_crossprod_dense_with_parallelism(
                x_threshold_exit,
                &h_exit,
                x_threshold_exit,
                product_parallelism,
            )? + weighted_crossprod_dense_with_parallelism(
                x_threshold_entry,
                &h_entry,
                x_threshold_entry,
                product_parallelism,
            )? + weighted_crossprod_dense_with_parallelism(
                x_t_deriv,
                &h_deriv,
                x_t_deriv,
                product_parallelism,
            )?;
            let cross = weighted_crossprod_dense_with_parallelism(
                x_threshold_exit,
                &h_exit_deriv,
                x_t_deriv,
                product_parallelism,
            )?;
            // Exit/derivative cross term in both orders keeps the block symmetric.
            h_tt += &cross;
            h_tt += &cross.t();
            Ok(h_tt)
        } else {
            let h_t = -(&q.d2_q1 * &q.dq_t.mapv(|v| safe_product(v, v))
                + &q.d2_q0 * &q.dq_t_entry.as_ref().unwrap().mapv(|v| safe_product(v, v))
                + &q.d2_qdot1 * &q.dqdot_t.mapv(|v| safe_product(v, v))
                + &q.d1_qdot1 * &q.d2qdot_tt);
            weighted_crossprod_dense_with_parallelism(
                x_threshold_exit,
                &h_t,
                x_threshold_exit,
                product_parallelism,
            )
        }
    };
    // Log-sigma block.
    let assemble_h_ll = || -> Result<Array2<f64>, String> {
        if let Some(x_ls_deriv) = x_log_sigma_deriv {
            let dq_ls_entry = q.dq_ls_entry.as_ref().unwrap();
            let d2q_ls_entry = q.d2q_ls_entry.as_ref().unwrap();
            let h_exit = -(&q.d2_q1 * &q.dq_ls.mapv(|v| safe_product(v, v))
                + &(&q.d1_q1 * &q.d2q_ls)
                + &q.d2_qdot1 * &q.dqdot_ls.mapv(|v| safe_product(v, v))
                + &(&q.d1_qdot1 * &q.d2qdot_ls));
            let h_entry = -(&q.d2_q0 * &dq_ls_entry.mapv(|v| safe_product(v, v))
                + &(&q.d1_q0 * d2q_ls_entry));
            let h_deriv = -(&q.d2_qdot1 * &q.dqdot_lsd.mapv(|v| safe_product(v, v)));
            let h_exit_deriv =
                -(&q.d2_qdot1 * &(&q.dqdot_ls * &q.dqdot_lsd) + &q.d1_qdot1 * &q.d2qdot_lslsd);
            let mut h_ll = weighted_crossprod_dense_with_parallelism(
                x_log_sigma_exit,
                &h_exit,
                x_log_sigma_exit,
                product_parallelism,
            )? + weighted_crossprod_dense_with_parallelism(
                x_log_sigma_entry,
                &h_entry,
                x_log_sigma_entry,
                product_parallelism,
            )? + weighted_crossprod_dense_with_parallelism(
                x_ls_deriv,
                &h_deriv,
                x_ls_deriv,
                product_parallelism,
            )?;
            let cross = weighted_crossprod_dense_with_parallelism(
                x_log_sigma_exit,
                &h_exit_deriv,
                x_ls_deriv,
                product_parallelism,
            )?;
            // Exit/derivative cross term in both orders keeps the block symmetric.
            h_ll += &cross;
            h_ll += &cross.t();
            Ok(h_ll)
        } else {
            let h_ls = -(&q.d2_q1 * &q.dq_ls.mapv(|v| safe_product(v, v))
                + &(&q.d1_q1 * &q.d2q_ls)
                + &q.d2_q0 * &q.dq_ls_entry.as_ref().unwrap().mapv(|v| safe_product(v, v))
                + &(&q.d1_q0 * q.d2q_ls_entry.as_ref().unwrap())
                + &q.d2_qdot1 * &q.dqdot_ls.mapv(|v| safe_product(v, v))
                + &(&q.d1_qdot1 * &q.d2qdot_ls));
            weighted_crossprod_dense_with_parallelism(
                x_log_sigma_exit,
                &h_ls,
                x_log_sigma_exit,
                product_parallelism,
            )
        }
    };
    // Link-wiggle block, present only when all three wiggle bases exist.
    let assemble_h_wiggle = || -> Result<Option<Array2<f64>>, String> {
        if let (Some(xw_exit), Some(xw_entry), Some(xw_qdot)) = (
            dynamic.wiggle_basis_exit.as_ref(),
            dynamic.wiggle_basis_entry.as_ref(),
            dynamic.wiggle_qdot_basis_exit.as_ref(),
        ) {
            Ok(Some(
                weighted_crossprod_dense_with_parallelism(
                    xw_exit,
                    &(-&q.d2_q1),
                    xw_exit,
                    product_parallelism,
                )? + weighted_crossprod_dense_with_parallelism(
                    xw_entry,
                    &(-&q.d2_q0),
                    xw_entry,
                    product_parallelism,
                )? + weighted_crossprod_dense_with_parallelism(
                    xw_qdot,
                    &(-&q.d2_qdot1),
                    xw_qdot,
                    product_parallelism,
                )?,
            ))
        } else {
            Ok(None)
        }
    };
    let (h_time, h_tt, h_ll, h_wiggle) = if use_outer_parallel {
        let ((h_time, h_tt), (h_ll, h_wiggle)) = rayon::join(
            || rayon::join(assemble_h_time, assemble_h_tt),
            || rayon::join(assemble_h_ll, assemble_h_wiggle),
        );
        (h_time?, h_tt?, h_ll?, h_wiggle?)
    } else {
        (
            assemble_h_time()?,
            assemble_h_tt()?,
            assemble_h_ll()?,
            assemble_h_wiggle()?,
        )
    };
    let mut blocks = vec![h_time, h_tt, h_ll];
    if let Some(hww) = h_wiggle {
        blocks.push(hww);
    }
    Ok(blocks)
}
pub(crate) fn assemble_joint_hessian_from_quantities(
&self,
q: &SurvivalJointQuantities,
block_states: &[ParameterBlockState],
) -> Result<Option<Array2<f64>>, String> {
self.assemble_joint_hessian_from_quantities_masked(q, block_states, None)
}
/// Assembles the full joint Hessian (within-block and cross-block terms)
/// from precomputed joint quantities, optionally restricted by a row mask.
///
/// `row_mask` is forwarded to every `mxtwx` / `mxtwxd` product; `None` means
/// all rows participate.
///
/// Blocks are laid out by `joint_block_offsets()`: time at `offsets[0]`,
/// threshold at `offsets[1]`, log-sigma at `offsets[2]`, and the link-wiggle
/// block at `offsets[3]`. The wiggle block is added only when the dynamic
/// geometry carries every wiggle basis and its derivative bases. Blocks are
/// written with `assign_symmetric_block`.
///
/// Always returns `Some` on success.
///
/// # Errors
/// Propagates failures from geometry construction, joint-state validation,
/// and the weighted-product helpers. Errors if the offsets are empty.
///
/// # Panics
/// Unwraps the optional entry quantities in `q` (`dq_t_entry`, `dq_ls_entry`,
/// `d2q_ls_entry`, `d2q_tls_entry`); these are assumed to be populated.
pub(crate) fn assemble_joint_hessian_from_quantities_masked(
    &self,
    q: &SurvivalJointQuantities,
    block_states: &[ParameterBlockState],
    row_mask: Option<&Array1<f64>>,
) -> Result<Option<Array2<f64>>, String> {
    let dynamic = self.build_dynamic_geometry(block_states)?;
    let joint_states = self.validate_joint_states(block_states)?;
    // NOTE(review): tuple slots 3/5/7/8 are read as exit/entry threshold
    // predictors and optional exit derivative predictors (per the local
    // names) — confirm against `validate_joint_states`.
    let eta_t_exit = joint_states.3;
    let eta_t_entry = joint_states.5;
    let eta_t_deriv_exit = joint_states.7;
    let eta_ls_deriv_exit = joint_states.8;
    // Missing derivative predictors are treated as zero.
    let eta_t_deriv_exit = eta_t_deriv_exit
        .map(|v| v.to_owned())
        .unwrap_or_else(|| Array1::zeros(self.n));
    let eta_ls_deriv_exit = eta_ls_deriv_exit
        .map(|v| v.to_owned())
        .unwrap_or_else(|| Array1::zeros(self.n));
    let offsets = self.joint_block_offsets();
    let p_total = *offsets
        .last()
        .ok_or_else(|| "missing joint block offsets".to_string())?;
    // Densify designs once; entry designs fall back to the exit design.
    let x_threshold_exit_cow = self.x_threshold.to_dense_cow();
    let x_threshold_exit = &*x_threshold_exit_cow;
    let x_threshold_entry_cow = self
        .x_threshold_entry
        .as_ref()
        .map(DesignMatrix::to_dense_cow);
    let x_threshold_entry = x_threshold_entry_cow
        .as_ref()
        .map_or(x_threshold_exit, |c| &**c);
    let x_threshold_deriv_cow = self
        .x_threshold_deriv
        .as_ref()
        .map(DesignMatrix::to_dense_cow);
    let x_threshold_deriv = x_threshold_deriv_cow.as_deref();
    let x_log_sigma_exit_cow = self.x_log_sigma.to_dense_cow();
    let x_log_sigma_exit = &*x_log_sigma_exit_cow;
    let x_log_sigma_entry_cow = self
        .x_log_sigma_entry
        .as_ref()
        .map(DesignMatrix::to_dense_cow);
    let x_log_sigma_entry = x_log_sigma_entry_cow
        .as_ref()
        .map_or(x_log_sigma_exit, |c| &**c);
    let x_log_sigma_deriv_cow = self
        .x_log_sigma_deriv
        .as_ref()
        .map(DesignMatrix::to_dense_cow);
    let x_log_sigma_deriv = x_log_sigma_deriv_cow.as_deref();
    let mut joint = Array2::<f64>::zeros((p_total, p_total));
    // Accumulates a masked weighted cross-product into `acc`.
    let add_cross = |acc: &mut Array2<f64>,
                     left: &Array2<f64>,
                     weights: &Array1<f64>,
                     right: &Array2<f64>|
     -> Result<(), String> {
        *acc += &mxtwx(left, weights, right, row_mask)?;
        Ok(())
    };
    // Time-time block.
    let nll_time = q.time_channel_nll_curvatures();
    let h_time = mxtwxd(&dynamic.time_jac_entry, &nll_time.h0, row_mask)
        + mxtwxd(&dynamic.time_jac_exit, &nll_time.h1, row_mask)
        + mxtwxd(&dynamic.time_jac_deriv, &nll_time.d, row_mask);
    assign_symmetric_block(&mut joint, offsets[0], offsets[0], &h_time);
    // Threshold-threshold block.
    if let Some(x_t_deriv) = x_threshold_deriv {
        let h_exit = -(&q.d2_q1 * &q.dq_t.mapv(|v| safe_product(v, v))
            + &q.d2_qdot1 * &q.dqdot_t.mapv(|v| safe_product(v, v))
            + &q.d1_qdot1 * &q.d2qdot_tt);
        let h_entry =
            -(&q.d2_q0 * &q.dq_t_entry.as_ref().unwrap().mapv(|v| safe_product(v, v)));
        let h_deriv = -(&q.d2_qdot1 * &q.dqdot_td.mapv(|v| safe_product(v, v)));
        let h_exit_deriv =
            -(&q.d2_qdot1 * &(&q.dqdot_t * &q.dqdot_td) + &q.d1_qdot1 * &q.d2qdot_ttd);
        let mut h_tt = mxtwx(x_threshold_exit, &h_exit, x_threshold_exit, row_mask)?
            + mxtwx(x_threshold_entry, &h_entry, x_threshold_entry, row_mask)?
            + mxtwx(x_t_deriv, &h_deriv, x_t_deriv, row_mask)?;
        let cross = mxtwx(x_threshold_exit, &h_exit_deriv, x_t_deriv, row_mask)?;
        // Exit/derivative cross term in both orders keeps the block symmetric.
        h_tt += &cross;
        h_tt += &cross.t();
        assign_symmetric_block(&mut joint, offsets[1], offsets[1], &h_tt);
    } else {
        let h_t = -(&q.d2_q1 * &q.dq_t.mapv(|v| safe_product(v, v))
            + &q.d2_q0 * &q.dq_t_entry.as_ref().unwrap().mapv(|v| safe_product(v, v))
            + &q.d2_qdot1 * &q.dqdot_t.mapv(|v| safe_product(v, v))
            + &q.d1_qdot1 * &q.d2qdot_tt);
        let h_tt = mxtwx(x_threshold_exit, &h_t, x_threshold_exit, row_mask)?;
        assign_symmetric_block(&mut joint, offsets[1], offsets[1], &h_tt);
    }
    // Log-sigma / log-sigma block.
    if let Some(x_ls_deriv) = x_log_sigma_deriv {
        let dq_ls_entry = q.dq_ls_entry.as_ref().unwrap();
        let d2q_ls_entry = q.d2q_ls_entry.as_ref().unwrap();
        let h_exit = -(&q.d2_q1 * &q.dq_ls.mapv(|v| safe_product(v, v))
            + &(&q.d1_q1 * &q.d2q_ls)
            + &q.d2_qdot1 * &q.dqdot_ls.mapv(|v| safe_product(v, v))
            + &(&q.d1_qdot1 * &q.d2qdot_ls));
        let h_entry = -(&q.d2_q0 * &dq_ls_entry.mapv(|v| safe_product(v, v))
            + &(&q.d1_q0 * d2q_ls_entry));
        let h_deriv = -(&q.d2_qdot1 * &q.dqdot_lsd.mapv(|v| safe_product(v, v)));
        let h_exit_deriv =
            -(&q.d2_qdot1 * &(&q.dqdot_ls * &q.dqdot_lsd) + &q.d1_qdot1 * &q.d2qdot_lslsd);
        let mut h_ll = mxtwx(x_log_sigma_exit, &h_exit, x_log_sigma_exit, row_mask)?
            + mxtwx(x_log_sigma_entry, &h_entry, x_log_sigma_entry, row_mask)?
            + mxtwx(x_ls_deriv, &h_deriv, x_ls_deriv, row_mask)?;
        let cross = mxtwx(x_log_sigma_exit, &h_exit_deriv, x_ls_deriv, row_mask)?;
        // Exit/derivative cross term in both orders keeps the block symmetric.
        h_ll += &cross;
        h_ll += &cross.t();
        assign_symmetric_block(&mut joint, offsets[2], offsets[2], &h_ll);
    } else {
        let h_ls = -(&q.d2_q1 * &q.dq_ls.mapv(|v| safe_product(v, v))
            + &(&q.d1_q1 * &q.d2q_ls)
            + &q.d2_q0 * &q.dq_ls_entry.as_ref().unwrap().mapv(|v| safe_product(v, v))
            + &(&q.d1_q0 * q.d2q_ls_entry.as_ref().unwrap())
            + &q.d2_qdot1 * &q.dqdot_ls.mapv(|v| safe_product(v, v))
            + &(&q.d1_qdot1 * &q.d2qdot_ls));
        let h_ll = mxtwx(x_log_sigma_exit, &h_ls, x_log_sigma_exit, row_mask)?;
        assign_symmetric_block(&mut joint, offsets[2], offsets[2], &h_ll);
    }
    // Threshold / log-sigma cross block.
    {
        let mut h_tl = Array2::<f64>::zeros((offsets[2] - offsets[1], offsets[3] - offsets[2]));
        let w_exit = -(&q.d2_q1 * &(&q.dq_t * &q.dq_ls) + &(&q.d1_q1 * &q.d2q_tls));
        let w_entry = -(&q.d2_q0
            * &(q.dq_t_entry.as_ref().unwrap() * q.dq_ls_entry.as_ref().unwrap())
            + &(&q.d1_q0 * q.d2q_tls_entry.as_ref().unwrap()));
        add_cross(&mut h_tl, x_threshold_exit, &w_exit, x_log_sigma_exit)?;
        add_cross(&mut h_tl, x_threshold_entry, &w_entry, x_log_sigma_entry)?;
        let w_qdot_exit =
            -(&q.d2_qdot1 * &(&q.dqdot_t * &q.dqdot_ls) + &(&q.d1_qdot1 * &q.d2qdot_tls));
        add_cross(&mut h_tl, x_threshold_exit, &w_qdot_exit, x_log_sigma_exit)?;
        if let Some(x_ls_deriv) = x_log_sigma_deriv {
            let w =
                -(&q.d2_qdot1 * &(&q.dqdot_t * &q.dqdot_lsd) + &(&q.d1_qdot1 * &q.d2qdot_tlsd));
            add_cross(&mut h_tl, x_threshold_exit, &w, x_ls_deriv)?;
        }
        if let Some(x_t_deriv) = x_threshold_deriv {
            let w =
                -(&q.d2_qdot1 * &(&q.dqdot_td * &q.dqdot_ls) + &(&q.d1_qdot1 * &q.d2qdot_lstd));
            add_cross(&mut h_tl, x_t_deriv, &w, x_log_sigma_exit)?;
            if let Some(x_ls_deriv) = x_log_sigma_deriv {
                let wdd = -(&q.d2_qdot1 * &(&q.dqdot_td * &q.dqdot_lsd));
                add_cross(&mut h_tl, x_t_deriv, &wdd, x_ls_deriv)?;
            }
        }
        assign_symmetric_block(&mut joint, offsets[1], offsets[2], &h_tl);
    }
    // Time / threshold cross block.
    let mut h_ht = mxtwx(
        &self.x_time_entry,
        &(&nll_time.h0 * q.dq_t_entry.as_ref().unwrap()),
        x_threshold_entry,
        row_mask,
    )? + mxtwx(
        &self.x_time_exit,
        &(&nll_time.h1 * &q.dq_t),
        x_threshold_exit,
        row_mask,
    )? + mxtwx(
        &self.x_time_deriv,
        &(&nll_time.d * &q.dqdot_t),
        x_threshold_exit,
        row_mask,
    )?;
    if let Some(x_t_deriv) = x_threshold_deriv {
        h_ht += &mxtwx(
            &self.x_time_deriv,
            &(&nll_time.d * &q.dqdot_td),
            x_t_deriv,
            row_mask,
        )?;
    }
    assign_symmetric_block(&mut joint, offsets[0], offsets[1], &h_ht);
    // Time / log-sigma cross block.
    let mut h_hl = mxtwx(
        &self.x_time_entry,
        &(&nll_time.h0 * q.dq_ls_entry.as_ref().unwrap()),
        x_log_sigma_entry,
        row_mask,
    )? + mxtwx(
        &self.x_time_exit,
        &(&nll_time.h1 * &q.dq_ls),
        x_log_sigma_exit,
        row_mask,
    )? + mxtwx(
        &self.x_time_deriv,
        &(&nll_time.d * &q.dqdot_ls),
        x_log_sigma_exit,
        row_mask,
    )?;
    if let Some(x_ls_deriv) = x_log_sigma_deriv {
        h_hl += &mxtwx(
            &self.x_time_deriv,
            &(&nll_time.d * &q.dqdot_lsd),
            x_ls_deriv,
            row_mask,
        )?;
    }
    assign_symmetric_block(&mut joint, offsets[0], offsets[2], &h_hl);
    // Link-wiggle diagonal block and its cross blocks with time/threshold/log-sigma.
    if let (
        Some(xw_exit),
        Some(xw_entry),
        Some(xw_qdot),
        Some(xw_d1_exit),
        Some(xw_d1_entry),
        Some(xw_d2_exit),
        Some(w_offset),
    ) = (
        dynamic.wiggle_basis_exit.as_ref(),
        dynamic.wiggle_basis_entry.as_ref(),
        dynamic.wiggle_qdot_basis_exit.as_ref(),
        dynamic.wiggle_basis_d1_exit.as_ref(),
        dynamic.wiggle_basis_d1_entry.as_ref(),
        dynamic.wiggle_basis_d2_exit.as_ref(),
        offsets.get(3).copied(),
    ) {
        let hww = mxtwx(xw_exit, &(-&q.d2_q1), xw_exit, row_mask)?
            + mxtwx(xw_entry, &(-&q.d2_q0), xw_entry, row_mask)?
            + mxtwx(xw_qdot, &(-&q.d2_qdot1), xw_qdot, row_mask)?;
        assign_symmetric_block(&mut joint, w_offset, w_offset, &hww);
        // First-order chain factors of q with respect to the threshold and
        // log-sigma predictors at entry and exit.
        let q0_t_entry = Array1::from_iter(dynamic.inv_sigma_entry.iter().map(|&r| -r));
        let q0_t_exit = Array1::from_iter(dynamic.inv_sigma_exit.iter().map(|&r| -r));
        let q0_ls_entry = Array1::from_iter(
            (0..self.n)
                .map(|i| q_chain_derivs_scalar(eta_t_entry[i], dynamic.eta_ls_entry[i]).1),
        );
        let q0_ls_exit = Array1::from_iter(
            (0..self.n).map(|i| q_chain_derivs_scalar(eta_t_exit[i], dynamic.eta_ls_exit[i]).1),
        );
        let r_base_exit = safe_linear_combo2_arrays(
            &q0_t_exit,
            &eta_t_deriv_exit,
            &q0_ls_exit,
            &eta_ls_deriv_exit,
        )?;
        let r_t_base_exit = Array1::from_iter((0..self.n).map(|i| {
            safe_product(
                q_chain_derivs_scalar(eta_t_exit[i], dynamic.eta_ls_exit[i]).2,
                eta_ls_deriv_exit[i],
            )
        }));
        let r_ls_base_exit = Array1::from_iter((0..self.n).map(|i| {
            let (_, _, q_tl, q_ll, _, _) =
                q_chain_derivs_scalar(eta_t_exit[i], dynamic.eta_ls_exit[i]);
            safe_sum2(
                safe_product(q_tl, eta_t_deriv_exit[i]),
                safe_product(q_ll, eta_ls_deriv_exit[i]),
            )
        }));
        let tw_entry_d2 = scale_dense_rows(xw_d1_entry, &q0_t_entry)?;
        let tw_exit_d2 = scale_dense_rows(xw_d1_exit, &q0_t_exit)?;
        let lw_entry_d2 = scale_dense_rows(xw_d1_entry, &q0_ls_entry)?;
        let lw_exit_d2 = scale_dense_rows(xw_d1_exit, &q0_ls_exit)?;
        let qdot_t_w = scale_dense_rows(
            xw_d2_exit,
            &safe_hadamard_product(&q0_t_exit, &r_base_exit)?,
        )? + scale_dense_rows(xw_d1_exit, &r_t_base_exit)?;
        let qdot_ls_w = scale_dense_rows(
            xw_d2_exit,
            &safe_hadamard_product(&q0_ls_exit, &r_base_exit)?,
        )? + scale_dense_rows(xw_d1_exit, &r_ls_base_exit)?;
        let qdot_td_w = scale_dense_rows(xw_d1_exit, &q0_t_exit)?;
        let qdot_lsd_w = scale_dense_rows(xw_d1_exit, &q0_ls_exit)?;
        // Threshold / wiggle cross block.
        let mut h_tw = Array2::<f64>::zeros((offsets[2] - offsets[1], offsets[4] - offsets[3]));
        h_tw += &mxtwx(x_threshold_exit, &(-&q.d2_q1 * &q.dq_t), xw_exit, row_mask)?;
        h_tw += &mxtwx(
            x_threshold_exit,
            &(-&q.d1_q1 * &q0_t_exit),
            &tw_exit_d2,
            row_mask,
        )?;
        h_tw += &mxtwx(
            x_threshold_entry,
            &(-&q.d2_q0 * q.dq_t_entry.as_ref().unwrap()),
            xw_entry,
            row_mask,
        )?;
        h_tw += &mxtwx(
            x_threshold_entry,
            &(-&q.d1_q0 * &q0_t_entry),
            &tw_entry_d2,
            row_mask,
        )?;
        h_tw += &mxtwx(
            x_threshold_exit,
            &(-&q.d2_qdot1 * &q.dqdot_t),
            xw_qdot,
            row_mask,
        )?;
        h_tw += &mxtwx(x_threshold_exit, &(-&q.d1_qdot1), &qdot_t_w, row_mask)?;
        if let Some(x_t_deriv) = x_threshold_deriv {
            h_tw += &mxtwx(x_t_deriv, &(-&q.d2_qdot1 * &q.dqdot_td), xw_qdot, row_mask)?;
            h_tw += &mxtwx(x_t_deriv, &(-&q.d1_qdot1), &qdot_td_w, row_mask)?;
        }
        assign_symmetric_block(&mut joint, offsets[1], w_offset, &h_tw);
        // Log-sigma / wiggle cross block.
        let mut h_lw = Array2::<f64>::zeros((offsets[3] - offsets[2], offsets[4] - offsets[3]));
        h_lw += &mxtwx(x_log_sigma_exit, &(-&q.d2_q1 * &q.dq_ls), xw_exit, row_mask)?;
        h_lw += &mxtwx(
            x_log_sigma_exit,
            &(-(&q.d1_q1 * &q0_ls_exit)),
            &lw_exit_d2,
            row_mask,
        )?;
        h_lw += &mxtwx(
            x_log_sigma_entry,
            &(-&q.d2_q0 * q.dq_ls_entry.as_ref().unwrap()),
            xw_entry,
            row_mask,
        )?;
        h_lw += &mxtwx(
            x_log_sigma_entry,
            &(-(&q.d1_q0 * &q0_ls_entry)),
            &lw_entry_d2,
            row_mask,
        )?;
        h_lw += &mxtwx(
            x_log_sigma_exit,
            &(-&q.d2_qdot1 * &q.dqdot_ls),
            xw_qdot,
            row_mask,
        )?;
        h_lw += &mxtwx(x_log_sigma_exit, &(-&q.d1_qdot1), &qdot_ls_w, row_mask)?;
        if let Some(x_ls_deriv) = x_log_sigma_deriv {
            h_lw += &mxtwx(
                x_ls_deriv,
                &(-&q.d2_qdot1 * &q.dqdot_lsd),
                xw_qdot,
                row_mask,
            )?;
            h_lw += &mxtwx(x_ls_deriv, &(-&q.d1_qdot1), &qdot_lsd_w, row_mask)?;
        }
        assign_symmetric_block(&mut joint, offsets[2], w_offset, &h_lw);
        // Time / wiggle cross block.
        let h_hw = mxtwx(&self.x_time_entry, &nll_time.h0, xw_entry, row_mask)?
            + mxtwx(&self.x_time_exit, &nll_time.h1, xw_exit, row_mask)?
            + mxtwx(&self.x_time_deriv, &nll_time.d, xw_qdot, row_mask)?;
        assign_symmetric_block(&mut joint, offsets[0], w_offset, &h_hw);
    }
    Ok(Some(joint))
}
/// Returns the log-scale shift applied to Hessian-derivative quantities, or
/// `0.0` when no shift is needed.
///
/// Only the standard `CLogLog` inverse link can produce a non-zero shift. Any
/// other link, or a failure to build the dynamic geometry, yields `0.0`.
///
/// Otherwise the largest `h + q` value over the entry and exit of every row
/// not skipped by the weight test is found. The function returns how far that
/// value exceeds `500.0`, clamped at zero.
///
/// NOTE(review): `500.0` presumably keeps `exp(h + q)` well inside the `f64`
/// range — confirm.
pub(crate) fn hessian_deriv_log_rescale(&self, block_states: &[ParameterBlockState]) -> f64 {
    let is_cloglog = matches!(
        self.inverse_link,
        InverseLink::Standard(StandardLink::CLogLog)
    );
    if !is_cloglog {
        return 0.0;
    }
    let Ok(geom) = self.build_dynamic_geometry(block_states) else {
        return 0.0;
    };
    // Rows with weight `<= 0.0` are skipped; NaN weights are still visited.
    let peak = (0..self.n)
        .filter(|&row| self.w[row] > 0.0 || self.w[row].is_nan())
        .fold(f64::NEG_INFINITY, |best, row| {
            let u_entry = geom.h_entry[row] + geom.q_entry[row];
            let u_exit = geom.h_exit[row] + geom.q_exit[row];
            best.max(u_entry).max(u_exit)
        });
    (peak - 500.0).max(0.0)
}
/// Computes the exact joint Hessian together with the log-scale shift that
/// was applied while building it.
///
/// When `hessian_deriv_log_rescale` reports no shift (`0.0`), this defers to
/// `exact_newton_joint_hessian` and pairs its result with `0.0`. Otherwise the
/// joint quantities are collected with the shift applied. The Hessian then
/// comes from, in priority order:
/// - the dense link-wiggle row-kernel path, when a link wiggle is configured;
/// - the generic row-kernel path, when supported;
/// - the quantity-based assembler otherwise.
pub(crate) fn exact_newton_joint_hessian_rescaled(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<Option<(Array2<f64>, f64)>, String> {
    let shift = self.hessian_deriv_log_rescale(block_states);
    if shift == 0.0 {
        let unscaled = self.exact_newton_joint_hessian(block_states)?;
        return Ok(unscaled.map(|hess| (hess, 0.0)));
    }
    let quantities = self.collect_joint_quantities_rescaled(block_states, shift)?;
    let hessian = if self.x_link_wiggle.is_some() {
        let geometry = self.build_dynamic_geometry(block_states)?;
        super::row_kernel::survival_ls_wiggle_joint_hessian_dense(
            self,
            &quantities,
            &geometry,
            shift,
        )?
    } else if self.row_kernel_joint_hessian_supported() {
        let geometry = self.build_dynamic_geometry(block_states)?;
        let kernel = self.survival_ls_row_kernel_rescaled(&quantities, &geometry, shift);
        let rows = crate::row_kernel::RowSet::All;
        let cache = crate::row_kernel::build_row_kernel_cache(&kernel, &rows)?;
        crate::row_kernel::row_kernel_hessian_dense(&kernel, &cache, &rows)
    } else {
        let assembled = self.assemble_joint_hessian_from_quantities(&quantities, block_states)?;
        return Ok(assembled.map(|hess| (hess, shift)));
    };
    Ok(Some((hessian, shift)))
}
/// Directional derivative of the exact joint Hessian along `d_beta_flat`,
/// built with the log-scale shift `log_rescale`.
///
/// Collects the rescaled joint quantities and the dynamic geometry for
/// `block_states`, then delegates to the `_from_parts` variant.
pub(crate) fn exact_newton_joint_hessian_directional_derivative_rescaled(
    &self,
    block_states: &[ParameterBlockState],
    d_beta_flat: &Array1<f64>,
    log_rescale: f64,
) -> Result<Option<Array2<f64>>, String> {
    let quantities = self.collect_joint_quantities_rescaled(block_states, log_rescale)?;
    let geometry = self.build_dynamic_geometry(block_states)?;
    self.exact_newton_joint_hessian_directional_derivative_rescaled_from_parts(
        d_beta_flat,
        &quantities,
        &geometry,
        log_rescale,
    )
}
/// Unmasked form of
/// `exact_newton_joint_hessian_directional_derivative_rescaled_from_parts_masked`:
/// every row contributes.
pub(crate) fn exact_newton_joint_hessian_directional_derivative_rescaled_from_parts(
    &self,
    d_beta_flat: &Array1<f64>,
    q: &SurvivalJointQuantities,
    dynamic: &SurvivalDynamicGeometry,
    deriv_log_scale: f64,
) -> Result<Option<Array2<f64>>, String> {
    let all_rows: Option<&Array1<f64>> = None;
    self.exact_newton_joint_hessian_directional_derivative_rescaled_from_parts_masked(
        d_beta_flat,
        q,
        dynamic,
        deriv_log_scale,
        all_rows,
    )
}
/// Directional derivative of the exact joint Hessian along `d_beta_flat`,
/// built from precomputed quantities and geometry, restricted to the rows
/// selected by `row_mask` (`None` = all rows).
///
/// Uses the generic row-kernel path when `row_kernel_directional_supported()`
/// holds, and the dense link-wiggle row-kernel path otherwise. Always returns
/// `Some` on success.
///
/// # Errors
/// Fails when the joint offsets are empty, when `d_beta_flat` does not have
/// the joint width, when it is not contiguous in memory, or when the selected
/// kernel path fails.
pub(crate) fn exact_newton_joint_hessian_directional_derivative_rescaled_from_parts_masked(
    &self,
    d_beta_flat: &Array1<f64>,
    q: &SurvivalJointQuantities,
    dynamic: &SurvivalDynamicGeometry,
    deriv_log_scale: f64,
    row_mask: Option<&Array1<f64>>,
) -> Result<Option<Array2<f64>>, String> {
    let offsets = self.joint_block_offsets();
    let Some(&p_total) = offsets.last() else {
        return Err("missing joint block offsets".to_string());
    };
    if d_beta_flat.len() != p_total {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "joint d_beta length mismatch: got {}, expected {p_total}",
                d_beta_flat.len()
            ),
        }
        .into());
    }
    // Both kernel paths need the same row selection and a contiguous direction.
    let rows = row_set_from_survival_mask(row_mask, self.n);
    let direction = d_beta_flat
        .as_slice()
        .ok_or_else(|| "joint d_beta must be contiguous".to_string())?;
    if !self.row_kernel_directional_supported() {
        return Ok(Some(
            super::row_kernel::survival_ls_wiggle_directional_derivative_dense(
                self,
                q,
                dynamic,
                deriv_log_scale,
                &rows,
                direction,
            )?,
        ));
    }
    let kernel = self.survival_ls_row_kernel_rescaled(q, dynamic, deriv_log_scale);
    crate::row_kernel::row_kernel_directional_derivative(&kernel, &rows, direction).map(Some)
}
/// Evaluates the log-likelihood and per-block gradients over all rows; the
/// unmasked form of `evaluate_log_likelihood_and_block_gradients_masked`.
pub(crate) fn evaluate_log_likelihood_and_block_gradients(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<(f64, Vec<Array1<f64>>), String> {
    let all_rows: Option<&Array1<f64>> = None;
    self.evaluate_log_likelihood_and_block_gradients_masked(block_states, all_rows)
}
/// Evaluates the log-likelihood and the per-block coefficient gradients,
/// weighting each row by `row_mask` (`None` means weight 1.0 for every row).
///
/// Returns `(ll, [grad_time, grad_threshold, grad_log_sigma, (grad_wiggle)])`.
/// The wiggle gradient is present only when the dynamic geometry carries all
/// three wiggle bases. Rows for which `row_derivatives` returns `None`
/// contribute nothing.
///
/// The sign convention comes from `row_derivatives`. The direct-MLE fitter in
/// this impl uses these gradients as an ascent direction for `ll`.
///
/// When `n >= EVALUATE_PARALLEL_ROW_THRESHOLD` and the rayon pool has more than
/// one thread, rows are processed in parallel chunks of 1024. Each chunk writes
/// only its own slice of the per-row buffers, and the partial `ll` sums are
/// reduced. Because floating-point addition is not associative, `ll` may
/// differ in the last bits from the serial path.
///
/// # Errors
/// Propagates failures from geometry construction and `row_derivatives`.
pub(crate) fn evaluate_log_likelihood_and_block_gradients_masked(
&self,
block_states: &[ParameterBlockState],
row_mask: Option<&Array1<f64>>,
) -> Result<(f64, Vec<Array1<f64>>), String> {
let n = self.n;
let dynamic = self.build_dynamic_geometry(block_states)?;
let mut ll = 0.0;
// Per-row first-derivative buffers (already multiplied by the row weight).
let mut grad_time_eta_h0 = Array1::<f64>::zeros(n);
let mut grad_time_eta_h1 = Array1::<f64>::zeros(n);
let mut grad_time_eta_d = Array1::<f64>::zeros(n);
let mut d1_q0 = Array1::<f64>::zeros(n);
let mut d1_q1 = Array1::<f64>::zeros(n);
let mut d1_qdot = Array1::<f64>::zeros(n);
let mask_at = |i: usize| -> f64 { row_mask.map_or(1.0, |m| m[i]) };
if n >= Self::EVALUATE_PARALLEL_ROW_THRESHOLD && rayon::current_num_threads() > 1 {
const CHUNK: usize = 1024;
// Freshly allocated `zeros` arrays are contiguous, so these cannot fail.
let d1_q0_s = d1_q0
.as_slice_memory_order_mut()
.expect("zeros is contiguous");
let d1_q1_s = d1_q1
.as_slice_memory_order_mut()
.expect("zeros is contiguous");
let d1_qdot_s = d1_qdot
.as_slice_memory_order_mut()
.expect("zeros is contiguous");
let g_h0_s = grad_time_eta_h0
.as_slice_memory_order_mut()
.expect("zeros is contiguous");
let g_h1_s = grad_time_eta_h1
.as_slice_memory_order_mut()
.expect("zeros is contiguous");
let g_d_s = grad_time_eta_d
.as_slice_memory_order_mut()
.expect("zeros is contiguous");
// Zip the six buffers chunk-by-chunk so each task owns disjoint slices;
// `chunk_idx * CHUNK` recovers the global row index of the chunk start.
ll = d1_q0_s
.par_chunks_mut(CHUNK)
.zip(d1_q1_s.par_chunks_mut(CHUNK))
.zip(d1_qdot_s.par_chunks_mut(CHUNK))
.zip(g_h0_s.par_chunks_mut(CHUNK))
.zip(g_h1_s.par_chunks_mut(CHUNK))
.zip(g_d_s.par_chunks_mut(CHUNK))
.enumerate()
.try_fold(
|| 0.0_f64,
|local_ll,
(chunk_idx, (((((d1q0_c, d1q1_c), d1qd_c), gh0_c), gh1_c), gd_c))|
-> Result<f64, String> {
let start = chunk_idx * CHUNK;
let mut acc = local_ll;
for local in 0..d1q0_c.len() {
let i = start + local;
let state = self.row_predictor_state(
dynamic.h_entry[i],
dynamic.h_exit[i],
dynamic.hdot_exit[i],
dynamic.q_entry[i],
dynamic.q_exit[i],
dynamic.qdot_exit[i],
);
if let Some(row) = self.row_derivatives(i, state)? {
let w = mask_at(i);
acc += row.ll * w;
d1q0_c[local] = row.d1_q0 * w;
d1q1_c[local] = row.d1_q1 * w;
d1qd_c[local] = row.d1_qdot1 * w;
gh0_c[local] = row.grad_time_eta_h0 * w;
gh1_c[local] = row.grad_time_eta_h1 * w;
gd_c[local] = row.grad_time_eta_d * w;
}
}
Ok(acc)
},
)
.try_reduce(|| 0.0_f64, |a, b| Ok::<_, String>(a + b))?;
} else {
// Serial path: same per-row computation as the parallel branch.
for i in 0..n {
let state = self.row_predictor_state(
dynamic.h_entry[i],
dynamic.h_exit[i],
dynamic.hdot_exit[i],
dynamic.q_entry[i],
dynamic.q_exit[i],
dynamic.qdot_exit[i],
);
let Some(row) = self.row_derivatives(i, state)? else {
continue;
};
let w = mask_at(i);
ll += row.ll * w;
d1_q0[i] = row.d1_q0 * w;
d1_q1[i] = row.d1_q1 * w;
d1_qdot[i] = row.d1_qdot1 * w;
grad_time_eta_h0[i] = row.grad_time_eta_h0 * w;
grad_time_eta_h1[i] = row.grad_time_eta_h1 * w;
grad_time_eta_d[i] = row.grad_time_eta_d * w;
}
}
// Time-block gradient: pull the per-row time-channel gradients back
// through the entry/exit/derivative time Jacobians.
let grad_time = dynamic.time_jac_entry.t().dot(&grad_time_eta_h0)
+ dynamic.time_jac_exit.t().dot(&grad_time_eta_h1)
+ dynamic.time_jac_deriv.t().dot(&grad_time_eta_d);
// Reused per-row scratch buffer for the threshold/log-sigma chain rules.
let mut scratch = Array1::<f64>::zeros(n);
// Threshold gradient. Separate entry and derivative designs are used only
// when BOTH are configured; otherwise every channel goes through
// `x_threshold`.
// NOTE(review): if exactly one of entry/deriv is Some, the `dqdot_td` term
// is not used here while the joint Hessian does use `x_threshold_deriv` —
// presumably the two are always configured together; confirm.
let grad_t = if let (Some(x_t_entry), Some(x_t_deriv)) = (
self.x_threshold_entry.as_ref(),
self.x_threshold_deriv.as_ref(),
) {
ndarray::Zip::from(&mut scratch)
.and(&d1_q1)
.and(&dynamic.dq_t_exit)
.and(&d1_qdot)
.and(&dynamic.dqdot_t)
.for_each(|s, &a, &b, &c, &d| *s = a * b + c * d);
let mut out = self.x_threshold.transpose_vector_multiply(&scratch);
ndarray::Zip::from(&mut scratch)
.and(&d1_q0)
.and(&dynamic.dq_t_entry)
.for_each(|s, &a, &b| *s = a * b);
out = out + x_t_entry.transpose_vector_multiply(&scratch);
ndarray::Zip::from(&mut scratch)
.and(&d1_qdot)
.and(&dynamic.dqdot_td)
.for_each(|s, &a, &b| *s = a * b);
out + x_t_deriv.transpose_vector_multiply(&scratch)
} else {
ndarray::Zip::from(&mut scratch)
.and(&d1_q1)
.and(&dynamic.dq_t_exit)
.and(&d1_q0)
.and(&dynamic.dq_t_entry)
.for_each(|s, &a, &b, &c, &d| *s = a * b + c * d);
ndarray::Zip::from(&mut scratch)
.and(&d1_qdot)
.and(&dynamic.dqdot_t)
.for_each(|s, &a, &b| *s += a * b);
self.x_threshold.transpose_vector_multiply(&scratch)
};
// Log-sigma gradient: same structure as the threshold gradient.
let grad_ls = if let (Some(x_ls_entry), Some(x_ls_deriv)) = (
self.x_log_sigma_entry.as_ref(),
self.x_log_sigma_deriv.as_ref(),
) {
ndarray::Zip::from(&mut scratch)
.and(&d1_q1)
.and(&dynamic.dq_ls_exit)
.and(&d1_qdot)
.and(&dynamic.dqdot_ls)
.for_each(|s, &a, &b, &c, &d| *s = a * b + c * d);
let mut out = self.x_log_sigma.transpose_vector_multiply(&scratch);
ndarray::Zip::from(&mut scratch)
.and(&d1_q0)
.and(&dynamic.dq_ls_entry)
.for_each(|s, &a, &b| *s = a * b);
out = out + x_ls_entry.transpose_vector_multiply(&scratch);
ndarray::Zip::from(&mut scratch)
.and(&d1_qdot)
.and(&dynamic.dqdot_lsd)
.for_each(|s, &a, &b| *s = a * b);
out + x_ls_deriv.transpose_vector_multiply(&scratch)
} else {
ndarray::Zip::from(&mut scratch)
.and(&d1_q1)
.and(&dynamic.dq_ls_exit)
.and(&d1_q0)
.and(&dynamic.dq_ls_entry)
.for_each(|s, &a, &b, &c, &d| *s = a * b + c * d);
ndarray::Zip::from(&mut scratch)
.and(&d1_qdot)
.and(&dynamic.dqdot_ls)
.for_each(|s, &a, &b| *s += a * b);
self.x_log_sigma.transpose_vector_multiply(&scratch)
};
let mut block_gradients = vec![grad_time, grad_t, grad_ls];
// Link-wiggle gradient, only when all three wiggle bases are present.
if let (Some(xw_exit), Some(xw_entry), Some(xw_qdot)) = (
dynamic.wiggle_basis_exit.as_ref(),
dynamic.wiggle_basis_entry.as_ref(),
dynamic.wiggle_qdot_basis_exit.as_ref(),
) {
let gradw =
xw_exit.t().dot(&d1_q1) + xw_entry.t().dot(&d1_q0) + xw_qdot.t().dot(&d1_qdot);
block_gradients.push(gradw);
}
Ok((ll, block_gradients))
}
/// Builds the effective Jacobian for parameter block `block_idx`.
///
/// The family is described to the shared additive-plus-wiggle layout with
/// three output channels: the time, threshold and log-sigma blocks are the
/// additive blocks, and `BLOCK_LINK_WIGGLE` is registered as the wiggle block.
///
/// # Errors
/// Propagates any error returned by the layout's `block_effective_jacobian`.
pub fn block_effective_jacobian(
    specs: &[ParameterBlockSpec],
    block_idx: usize,
) -> Result<Box<dyn BlockEffectiveJacobian>, String> {
    use crate::block_layout::block_jacobian::AdditiveWiggleBlockLayout;
    let layout = AdditiveWiggleBlockLayout {
        family: "SurvivalLocationScaleFamily",
        n_outputs: 3,
        additive_blocks: &[Self::BLOCK_TIME, Self::BLOCK_THRESHOLD, Self::BLOCK_LOG_SIGMA],
        wiggle_block: Some(Self::BLOCK_LINK_WIGGLE),
    };
    layout.block_effective_jacobian(specs, block_idx)
}
}
/// Per-subject KxK channel Hessian (K = 3) for the survival location-scale family.
///
/// `h` is indexed as `h[[subject, channel_a, channel_b]]` (see `fill_subject`);
/// `from_full` asserts that both trailing axes have length `K`.
pub struct SurvivalLocationScaleChannelHessian {
pub(crate) h: ndarray::Array3<f64>,
}
impl SurvivalLocationScaleChannelHessian {
    /// Number of output channels per subject.
    pub const K: usize = 3;

    /// Wraps a precomputed `(n_subjects, K, K)` array of per-subject Hessians.
    ///
    /// # Panics
    /// Panics if either trailing axis of `h` does not have length `K`.
    pub fn from_full(h: ndarray::Array3<f64>) -> Self {
        // Axes 1 and 2 are the two channel axes; both must be exactly K long.
        for axis in 1..=2 {
            assert_eq!(
                h.shape()[axis],
                Self::K,
                "SurvivalLocationScaleChannelHessian: expected K={} channels, got {}",
                Self::K,
                h.shape()[axis],
            );
        }
        Self { h }
    }

    /// Builds `n` identity KxK blocks (one per subject).
    pub fn identity(n: usize) -> Self {
        let h = ndarray::Array3::<f64>::from_shape_fn((n, Self::K, Self::K), |(_, a, b)| {
            if a == b { 1.0 } else { 0.0 }
        });
        Self { h }
    }
}
impl FamilyChannelHessian for SurvivalLocationScaleChannelHessian {
    /// Always `K` channels.
    fn n_outputs(&self) -> usize {
        Self::K
    }

    /// Number of subjects, i.e. the length of the leading axis of `h`.
    fn n_subjects(&self) -> usize {
        self.h.len_of(ndarray::Axis(0))
    }

    /// Copies subject `i`'s KxK block into `out` in row-major order
    /// (`out[a * K + b] = h[i, a, b]`).
    ///
    /// # Panics
    /// Panics if `out.len() != K * K` or if `i` is out of range.
    fn fill_subject(&self, i: usize, out: &mut [f64]) {
        let k = Self::K;
        assert_eq!(out.len(), k * k);
        for (flat, slot) in out.iter_mut().enumerate() {
            // Flat index decomposes into (row, column) of the KxK block.
            *slot = self.h[[i, flat / k, flat % k]];
        }
    }

    /// Returns an owned copy of the full `(n_subjects, K, K)` array.
    fn evaluate_full(&self) -> ndarray::Array3<f64> {
        self.h.clone()
    }
}
/// Free-function form of `SurvivalLocationScaleFamily::block_effective_jacobian`;
/// `specs` and `block_idx` are forwarded unchanged.
pub fn survival_location_scale_block_effective_jacobian(
specs: &[ParameterBlockSpec],
block_idx: usize,
) -> Result<Box<dyn BlockEffectiveJacobian>, String> {
SurvivalLocationScaleFamily::block_effective_jacobian(specs, block_idx)
}
impl CustomFamily for SurvivalLocationScaleFamily {
/// Always `true`: this family opts in to the joint Jeffreys term.
fn joint_jeffreys_term_required(&self) -> bool {
true
}
/// Directional derivatives of the (rescaled) joint Hessian along every
/// coordinate axis of the flattened coefficient vector.
///
/// The coefficient count is taken from `specs`; `Ok(None)` is returned when it
/// is zero. When the row-kernel directional path is supported, all axes are
/// produced in one call by the row-kernel helper. Otherwise, axis `a` is the
/// directional derivative along the unit vector `e_a`. If that returns `None`
/// for any axis, the whole result is `Ok(None)`.
///
/// # Errors
/// Propagates errors from quantity/geometry assembly and the derivative helpers.
fn joint_jeffreys_information_directional_derivative_all_axes_with_specs(
    &self,
    block_states: &[ParameterBlockState],
    specs: &[ParameterBlockSpec],
) -> Result<Option<Vec<Array2<f64>>>, String> {
    let p_total = specs.iter().map(|spec| spec.design.ncols()).sum::<usize>();
    if p_total == 0 {
        return Ok(None);
    }
    let log_rescale = self.hessian_deriv_log_rescale(block_states);
    let q = self.collect_joint_quantities_rescaled(block_states, log_rescale)?;
    let dynamic = self.build_dynamic_geometry(block_states)?;
    if self.row_kernel_directional_supported() {
        let kernel = self.survival_ls_row_kernel_rescaled(&q, &dynamic, log_rescale);
        let rows = crate::row_kernel::RowSet::All;
        let axes =
            crate::row_kernel::row_kernel_directional_derivative_all_axes(&kernel, &rows)?;
        return Ok(Some(axes));
    }
    // Reuse one unit-vector buffer: set coordinate `a`, evaluate, then clear it.
    // This avoids allocating and zeroing a fresh length-p_total vector per axis.
    let mut e_a = Array1::<f64>::zeros(p_total);
    let mut axes = Vec::with_capacity(p_total);
    for a in 0..p_total {
        e_a[a] = 1.0;
        let derivative = self
            .exact_newton_joint_hessian_directional_derivative_rescaled_from_parts(
                &e_a,
                &q,
                &dynamic,
                log_rescale,
            );
        e_a[a] = 0.0;
        match derivative? {
            Some(m) => axes.push(m),
            None => return Ok(None),
        }
    }
    Ok(Some(axes))
}
/// Always `true`: the joint exact-Newton Hessian is rebuilt from the current
/// block states (see `exact_newton_joint_hessian`), so it depends on beta.
fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
true
}
/// Maps each block spec, by name, to the output channel it feeds.
///
/// Channel 0 is `time_transform`, 1 is `threshold`, 2 is `log_sigma`. Any
/// other name also maps to channel 0. That presumably covers the link-wiggle
/// block; TODO confirm its spec name and whether channel 0 is intended.
fn output_channel_assignment(&self, specs: &[ParameterBlockSpec]) -> Option<Vec<usize>> {
    let channel_of = |name: &str| -> usize {
        match name {
            "time_transform" => 0,
            "threshold" => 1,
            "log_sigma" => 2,
            _ => 0,
        }
    };
    let channels: Vec<usize> = specs
        .iter()
        .map(|spec| channel_of(spec.name.as_str()))
        .collect();
    Some(channels)
}
/// Coefficient-Hessian cost estimate, delegated to the shared
/// `joint_coupled_coefficient_hessian_cost` helper with this family's row
/// count `self.n`.
fn coefficient_hessian_cost(&self, specs: &[crate::custom_family::ParameterBlockSpec]) -> u64 {
crate::custom_family::joint_coupled_coefficient_hessian_cost(self.n as u64, specs)
}
/// Reports whether outer hyper-Hessian HVPs are available.
///
/// This is true exactly when `specs` pass `validate_joint_specs`; the
/// validation error itself is discarded.
fn outer_hyper_hessian_hvp_available(
    &self,
    specs: &[crate::custom_family::ParameterBlockSpec],
) -> bool {
    let context = "SurvivalLocationScaleFamily outer hyper Hessian HVP availability";
    match self.validate_joint_specs(specs, context) {
        Ok(_) => true,
        Err(_) => false,
    }
}
/// Reports whether the dense outer hyper-Hessian is available.
///
/// It is available whenever the shared heuristic does not select the
/// matrix-free joint path for this total coefficient count and row count.
fn outer_hyper_hessian_dense_available(
    &self,
    specs: &[crate::custom_family::ParameterBlockSpec],
) -> bool {
    let p_total = specs
        .iter()
        .fold(0usize, |acc, spec| acc + spec.design.ncols());
    let matrix_free = crate::custom_family::use_joint_matrix_free_path(p_total, self.n);
    !matrix_free
}
/// Evaluates the log-likelihood and one exact-Newton working set
/// (gradient plus dense Hessian) per parameter block.
///
/// Without row-kernel joint Hessian support, block gradients and
/// block-diagonal Hessians are assembled separately and zipped together.
/// With support, the row kernel provides the joint gradient (negated, since
/// the kernel gradient is of the negative log-likelihood) and the dense joint
/// Hessian. Each block then receives its diagonal slice.
///
/// # Errors
/// Propagates assembly errors, and reports a dimension mismatch if the number
/// of block gradients and block Hessians differ.
fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
    if !self.row_kernel_joint_hessian_supported() {
        let (log_likelihood, block_gradients) =
            self.evaluate_log_likelihood_and_block_gradients(block_states)?;
        let q = self.collect_joint_quantities(block_states)?;
        let block_hessians =
            self.assemble_block_diagonal_hessians_from_quantities(&q, block_states)?;
        if block_gradients.len() != block_hessians.len() {
            return Err(SurvivalLocationScaleError::DimensionMismatch {
                reason: format!(
                    "SurvivalLocationScaleFamily evaluate block count mismatch: gradients={}, hessians={}",
                    block_gradients.len(),
                    block_hessians.len()
                ),
            }
            .into());
        }
        let blockworking_sets = block_gradients
            .into_iter()
            .zip(block_hessians)
            .map(|(gradient, hessian)| BlockWorkingSet::ExactNewton {
                gradient,
                hessian: SymmetricMatrix::Dense(hessian),
            })
            .collect();
        return Ok(FamilyEvaluation {
            log_likelihood,
            blockworking_sets,
        });
    }

    // Row-kernel path: one joint gradient/Hessian, sliced per block below.
    let q = self.collect_joint_quantities(block_states)?;
    let dynamic = self.build_dynamic_geometry(block_states)?;
    let kernel = self.survival_ls_row_kernel(&q, &dynamic);
    let rows = crate::row_kernel::RowSet::All;
    let cache = crate::row_kernel::build_row_kernel_cache(&kernel, &rows)?;
    let log_likelihood = crate::row_kernel::row_kernel_log_likelihood(&cache, &rows);
    let nll_gradient = crate::row_kernel::row_kernel_gradient(&kernel, &cache, &rows);
    let loglik_gradient = -nll_gradient;
    let joint_hessian = crate::row_kernel::row_kernel_hessian_dense(&kernel, &cache, &rows);
    let offsets = self.joint_block_offsets();
    let blockworking_sets = (0..self.expected_blocks())
        .map(|b| {
            let span = offsets[b]..offsets[b + 1];
            BlockWorkingSet::ExactNewton {
                gradient: loglik_gradient.slice(s![span.clone()]).to_owned(),
                hessian: SymmetricMatrix::Dense(
                    joint_hessian.slice(s![span.clone(), span]).to_owned(),
                ),
            }
        })
        .collect();
    Ok(FamilyEvaluation {
        log_likelihood,
        blockworking_sets,
    })
}
/// Total log-likelihood over all `self.n` rows, without derivatives.
///
/// Rows whose exact row kernel is `None` contribute 0. Small problems are
/// summed serially. Larger ones are split into fixed-size row chunks, summed
/// in parallel with rayon, and the chunk sums are then added in chunk order,
/// so the floating-point result is deterministic.
///
/// # Errors
/// Returns the first row error (in row/chunk order) and propagates geometry
/// errors.
fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    // Row counts below this are summed on the calling thread.
    const PARALLEL_LOG_LIKELIHOOD_ROW_THRESHOLD: usize = 1024;
    // Number of consecutive rows summed by one parallel task.
    const LOG_LIKELIHOOD_CHUNK_ROWS: usize = 1024;

    let n = self.n;
    let dynamic = self.build_dynamic_geometry(block_states)?;
    let row_ll = |i: usize| -> Result<f64, String> {
        let state = self.row_predictor_state(
            dynamic.h_entry[i],
            dynamic.h_exit[i],
            dynamic.hdot_exit[i],
            dynamic.q_entry[i],
            dynamic.q_exit[i],
            dynamic.qdot_exit[i],
        );
        let kernel = self.exact_row_kernel(i, state)?;
        Ok(kernel.map_or(0.0, SurvivalExactRowKernel::log_likelihood))
    };
    // Left-to-right sum over rows `start..end`, stopping at the first error.
    let sum_range = |start: usize, end: usize| -> Result<f64, String> {
        (start..end).try_fold(0.0_f64, |acc: f64, i: usize| -> Result<f64, String> {
            Ok(acc + row_ll(i)?)
        })
    };
    if n < PARALLEL_LOG_LIKELIHOOD_ROW_THRESHOLD {
        return sum_range(0, n);
    }
    let chunk_sums: Vec<Result<f64, String>> = (0..n.div_ceil(LOG_LIKELIHOOD_CHUNK_ROWS))
        .into_par_iter()
        .map(|chunk| {
            let start = chunk * LOG_LIKELIHOOD_CHUNK_ROWS;
            sum_range(start, (start + LOG_LIKELIHOOD_CHUNK_ROWS).min(n))
        })
        .collect();
    chunk_sums
        .into_iter()
        .try_fold(0.0_f64, |acc: f64, part: Result<f64, String>| -> Result<f64, String> {
            Ok(acc + part?)
        })
}
/// Log-likelihood honouring an optional outer-score row subsample.
///
/// Without a subsample this is `log_likelihood_only`. With one, only the
/// listed rows are evaluated, each scaled by its subsample weight. Rows whose
/// exact kernel is `None` contribute 0.
///
/// # Errors
/// Returns a dimension-mismatch error for a subsample index `>= self.n`, and
/// propagates geometry/row-kernel errors.
fn log_likelihood_only_with_options(
    &self,
    block_states: &[ParameterBlockState],
    options: &BlockwiseFitOptions,
) -> Result<f64, String> {
    let subsample = match options.outer_score_subsample.as_ref() {
        Some(subsample) => subsample,
        None => return self.log_likelihood_only(block_states),
    };
    let n = self.n;
    let dynamic = self.build_dynamic_geometry(block_states)?;
    let mut total = 0.0;
    for row in subsample.rows.as_ref() {
        let i = row.index;
        if i >= n {
            return Err(SurvivalLocationScaleError::DimensionMismatch {
                reason: format!(
                    "SurvivalLocationScaleFamily outer subsample row index {i} out of bounds for n={n}"
                ),
            }
            .into());
        }
        let state = self.row_predictor_state(
            dynamic.h_entry[i],
            dynamic.h_exit[i],
            dynamic.hdot_exit[i],
            dynamic.q_entry[i],
            dynamic.q_exit[i],
            dynamic.qdot_exit[i],
        );
        let row_ll = self
            .exact_row_kernel(i, state)?
            .map_or(0.0, SurvivalExactRowKernel::log_likelihood);
        total += row.weight * row_ll;
    }
    Ok(total)
}
/// Block-diagonal piece of the joint Hessian's directional derivative.
///
/// `d_beta` is embedded into a zero joint direction at block `block_idx`. The
/// unrescaled (log-rescale 0.0) joint directional derivative is computed, and
/// its `block_idx` diagonal block is returned. Out-of-range `block_idx` yields
/// `Ok(None)`.
///
/// # Errors
/// Errors when `d_beta` has the wrong length, when the joint directional
/// derivative is unavailable (`None`), or when that computation fails.
fn exact_newton_hessian_directional_derivative(
    &self,
    block_states: &[ParameterBlockState],
    block_idx: usize,
    d_beta: &Array1<f64>,
) -> Result<Option<Array2<f64>>, String> {
    let dims = self.joint_block_dims();
    let Some(&expected_len) = dims.get(block_idx) else {
        return Ok(None);
    };
    if d_beta.len() != expected_len {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "block {block_idx} d_beta length mismatch: got {}, expected {}",
                d_beta.len(),
                expected_len
            ),
        }
        .into());
    }
    let offsets = self.joint_block_offsets();
    let p_total = *offsets.last().unwrap();
    let block_span = offsets[block_idx]..offsets[block_idx + 1];
    let mut joint_direction = Array1::<f64>::zeros(p_total);
    joint_direction
        .slice_mut(s![block_span.clone()])
        .assign(d_beta);
    let joint_derivative = self
        .exact_newton_joint_hessian_directional_derivative_rescaled(
            block_states,
            &joint_direction,
            0.0,
        )?
        .ok_or_else(|| {
            "missing survival location-scale exact joint directional Hessian".to_string()
        })?;
    let diagonal_block = joint_derivative
        .slice(s![block_span.clone(), block_span])
        .to_owned();
    Ok(Some(diagonal_block))
}
/// Dense joint exact-Newton Hessian at the current block states (no rescaling).
///
/// The method is chosen in order:
/// 1. the wiggle-aware dense Hessian when a link-wiggle design is present;
/// 2. the row-kernel dense Hessian when supported;
/// 3. otherwise, assembly from the collected joint quantities.
fn exact_newton_joint_hessian(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<Option<Array2<f64>>, String> {
    let q = self.collect_joint_quantities(block_states)?;
    let has_wiggle = self.x_link_wiggle.is_some();
    if !has_wiggle && !self.row_kernel_joint_hessian_supported() {
        return self.assemble_joint_hessian_from_quantities(&q, block_states);
    }
    let dynamic = self.build_dynamic_geometry(block_states)?;
    let hessian = if has_wiggle {
        super::row_kernel::survival_ls_wiggle_joint_hessian_dense(self, &q, &dynamic, 0.0)?
    } else {
        let kernel = self.survival_ls_row_kernel(&q, &dynamic);
        let rows = crate::row_kernel::RowSet::All;
        let cache = crate::row_kernel::build_row_kernel_cache(&kernel, &rows)?;
        crate::row_kernel::row_kernel_hessian_dense(&kernel, &cache, &rows)
    };
    Ok(Some(hessian))
}
/// Joint log-likelihood gradient from the row kernel.
///
/// Returns `Ok(None)` when the row-kernel joint path is unsupported. Otherwise
/// it returns the negated row-kernel gradient (the kernel gradient is of the
/// negative log-likelihood).
fn exact_newton_joint_loglik_gradient(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<Option<Array1<f64>>, String> {
    if self.row_kernel_joint_hessian_supported() {
        let q = self.collect_joint_quantities(block_states)?;
        let dynamic = self.build_dynamic_geometry(block_states)?;
        let kernel = self.survival_ls_row_kernel(&q, &dynamic);
        let all_rows = crate::row_kernel::RowSet::All;
        let cache = crate::row_kernel::build_row_kernel_cache(&kernel, &all_rows)?;
        let negative_loglik_gradient =
            crate::row_kernel::row_kernel_gradient(&kernel, &cache, &all_rows);
        Ok(Some(-negative_loglik_gradient))
    } else {
        Ok(None)
    }
}
/// Log-likelihood plus the flattened joint gradient.
///
/// The per-block gradients are concatenated in spec order. Each block's
/// length must equal its spec's design width.
///
/// # Errors
/// Dimension-mismatch errors for a block-count or block-width mismatch, and
/// propagation of gradient-assembly errors.
fn exact_newton_joint_gradient_evaluation(
    &self,
    block_states: &[ParameterBlockState],
    specs: &[ParameterBlockSpec],
) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
    let (log_likelihood, block_gradients) =
        self.evaluate_log_likelihood_and_block_gradients(block_states)?;
    if block_gradients.len() != specs.len() {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "SurvivalLocationScaleFamily joint gradient block count mismatch: gradients={}, specs={}",
                block_gradients.len(),
                specs.len()
            ),
        }
        .into());
    }
    let total_p: usize = specs.iter().map(|spec| spec.design.ncols()).sum();
    let mut gradient = Array1::<f64>::zeros(total_p);
    let mut cursor = 0usize;
    for (block_idx, spec) in specs.iter().enumerate() {
        let block_gradient = &block_gradients[block_idx];
        let width = spec.design.ncols();
        if block_gradient.len() != width {
            return Err(SurvivalLocationScaleError::DimensionMismatch {
                reason: format!(
                    "SurvivalLocationScaleFamily joint gradient length mismatch for block {block_idx}: got {}, expected {}",
                    block_gradient.len(),
                    width
                ),
            }
            .into());
        }
        let next = cursor + width;
        gradient.slice_mut(s![cursor..next]).assign(block_gradient);
        cursor = next;
    }
    Ok(Some(ExactNewtonJointGradientEvaluation {
        log_likelihood,
        gradient,
    }))
}
/// Always `true`: this family implements `exact_newton_joint_hessian` directly.
fn has_explicit_joint_hessian(&self) -> bool {
true
}
/// Outer curvature built from the rescaled joint Hessian.
///
/// For a returned `(hessian, log_scale)` pair, `rho_curvature_scale` is
/// `exp(-log_scale)` and `hessian_logdet_correction` is `p * log_scale`,
/// where `p` is the Hessian's row count. `None` from the rescaled Hessian is
/// passed through.
fn exact_newton_outer_curvature(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<Option<ExactNewtonOuterCurvature>, String> {
    let Some((hessian, log_scale)) = self.exact_newton_joint_hessian_rescaled(block_states)?
    else {
        return Ok(None);
    };
    let dim = hessian.nrows() as f64;
    let rho_curvature_scale = (-log_scale).exp();
    let hessian_logdet_correction = dim * log_scale;
    Ok(Some(ExactNewtonOuterCurvature {
        hessian,
        rho_curvature_scale,
        hessian_logdet_correction,
    }))
}
/// Joint Hessian directional derivative along `d_beta_flat`.
///
/// Evaluated at the family's `hessian_deriv_log_rescale` scale.
fn exact_newton_joint_hessian_directional_derivative(
    &self,
    block_states: &[ParameterBlockState],
    d_beta_flat: &Array1<f64>,
) -> Result<Option<Array2<f64>>, String> {
    let log_rescale = self.hessian_deriv_log_rescale(block_states);
    self.exact_newton_joint_hessian_directional_derivative_rescaled(
        block_states,
        d_beta_flat,
        log_rescale,
    )
}
/// First-order psi terms of the joint exact-Newton system over all rows.
///
/// Delegates to `exact_newton_joint_psi_terms_masked` with no row mask.
fn exact_newton_joint_psi_terms(
    &self,
    block_states: &[ParameterBlockState],
    specs: &[ParameterBlockSpec],
    derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
    psi_index: usize,
) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
    let all_rows_unmasked = None;
    self.exact_newton_joint_psi_terms_masked(
        block_states,
        specs,
        derivative_blocks,
        psi_index,
        all_rows_unmasked,
    )
}
/// Second-order psi terms of the joint exact-Newton system.
///
/// This family validates its inputs but does not supply second-order psi
/// terms: after validation it always returns `Ok(None)`, whatever the values
/// of `psi_i` and `psi_j`.
///
/// # Errors
/// Returns a dimension-mismatch error when `block_states` or
/// `derivative_blocks` does not have `expected_blocks()` entries, and
/// propagates errors from `validate_joint_specs`.
fn exact_newton_joint_psisecond_order_terms(
    &self,
    block_states: &[ParameterBlockState],
    specs: &[ParameterBlockSpec],
    derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
    psi_i: usize,
    psi_j: usize,
) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String> {
    if block_states.len() != self.expected_blocks()
        || derivative_blocks.len() != self.expected_blocks()
    {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "SurvivalLocationScaleFamily joint psi second-order terms expect {} states and derivative blocks, got {} / {}",
                self.expected_blocks(),
                block_states.len(),
                derivative_blocks.len()
            ),
        }
        .into());
    }
    self.validate_joint_specs(
        specs,
        "SurvivalLocationScaleFamily joint psi second-order terms",
    )?;
    // The psi indices do not influence the outcome: in-range and out-of-range
    // indices both yield `None`, so no bounds check is performed on them.
    let _ = (psi_i, psi_j);
    Ok(None)
}
/// Builds a psi workspace over all rows for the current states.
///
/// # Errors
/// Dimension mismatch unless `block_states`, `specs` and `derivative_blocks`
/// each have `expected_blocks()` entries; propagates workspace construction
/// errors.
fn exact_newton_joint_psi_workspace(
    &self,
    block_states: &[ParameterBlockState],
    specs: &[ParameterBlockSpec],
    derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
) -> Result<Option<Arc<dyn ExactNewtonJointPsiWorkspace>>, String> {
    let expected = self.expected_blocks();
    let counts = [block_states.len(), specs.len(), derivative_blocks.len()];
    if counts.iter().any(|&count| count != expected) {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "SurvivalLocationScaleFamily joint psi workspace expects {} states, specs, and derivative blocks, got {} / {} / {}",
                expected,
                counts[0],
                counts[1],
                counts[2]
            ),
        }
        .into());
    }
    let workspace = SurvivalExactNewtonJointPsiWorkspace::new(
        self.clone(),
        block_states.to_vec(),
        specs.to_vec(),
        derivative_blocks.to_vec(),
    )?;
    Ok(Some(Arc::new(workspace)))
}
/// Builds a psi workspace, restricted to `options.outer_score_subsample` when set.
///
/// # Errors
/// Dimension mismatch unless `block_states`, `specs` and `derivative_blocks`
/// each have `expected_blocks()` entries; propagates workspace construction
/// errors.
fn exact_newton_joint_psi_workspace_with_options(
    &self,
    block_states: &[ParameterBlockState],
    specs: &[ParameterBlockSpec],
    derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
    options: &BlockwiseFitOptions,
) -> Result<Option<Arc<dyn ExactNewtonJointPsiWorkspace>>, String> {
    let expected = self.expected_blocks();
    let counts = [block_states.len(), specs.len(), derivative_blocks.len()];
    if counts.iter().any(|&count| count != expected) {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "SurvivalLocationScaleFamily joint psi workspace expects {} states, specs, and derivative blocks, got {} / {} / {}",
                expected,
                counts[0],
                counts[1],
                counts[2]
            ),
        }
        .into());
    }
    let mut workspace = SurvivalExactNewtonJointPsiWorkspace::new(
        self.clone(),
        block_states.to_vec(),
        specs.to_vec(),
        derivative_blocks.to_vec(),
    )?;
    if let Some(subsample) = &options.outer_score_subsample {
        workspace.apply_outer_subsample(subsample.rows.as_ref());
    }
    Ok(Some(Arc::new(workspace)))
}
/// Directional derivative of the psi-Hessian along `d_beta_flat`.
///
/// This family validates its inputs but does not supply this quantity: after
/// validation it always returns `Ok(None)`, whatever the value of `psi_index`.
///
/// # Errors
/// Dimension-mismatch errors when `block_states`/`derivative_blocks` do not
/// have `expected_blocks()` entries or when `d_beta_flat` does not match the
/// joint coefficient count. Also propagates `validate_joint_specs` errors.
fn exact_newton_joint_psihessian_directional_derivative(
    &self,
    block_states: &[ParameterBlockState],
    specs: &[ParameterBlockSpec],
    derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
    psi_index: usize,
    d_beta_flat: &Array1<f64>,
) -> Result<Option<Array2<f64>>, String> {
    if block_states.len() != self.expected_blocks()
        || derivative_blocks.len() != self.expected_blocks()
    {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "SurvivalLocationScaleFamily joint psi Hessian directional derivative expects {} states and derivative blocks, got {} / {}",
                self.expected_blocks(),
                block_states.len(),
                derivative_blocks.len()
            ),
        }
        .into());
    }
    self.validate_joint_specs(
        specs,
        "SurvivalLocationScaleFamily joint psi Hessian directional derivative",
    )?;
    let p_total = *self
        .joint_block_offsets()
        .last()
        .ok_or_else(|| "missing joint block offsets".to_string())?;
    if d_beta_flat.len() != p_total {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "joint psi Hessian directional derivative d_beta length mismatch: got {}, expected {p_total}",
                d_beta_flat.len()
            ),
        }
        .into());
    }
    // The psi index does not influence the outcome: in-range and out-of-range
    // indices both yield `None`, so no bounds check is performed on it.
    let _ = psi_index;
    Ok(None)
}
/// Second directional derivative of the (rescaled) joint Hessian along
/// `d_beta_u_flat` and `d_beta_v_flat`.
///
/// Uses the wiggle-aware dense routine when a link-wiggle design is present,
/// and the generic row-kernel routine otherwise. Both directions must be
/// contiguous.
///
/// # Errors
/// Block-count and direction-length mismatches, non-contiguous directions, and
/// propagated assembly/derivative errors.
fn exact_newton_joint_hessiansecond_directional_derivative(
    &self,
    block_states: &[ParameterBlockState],
    d_beta_u_flat: &Array1<f64>,
    d_beta_v_flat: &Array1<f64>,
) -> Result<Option<Array2<f64>>, String> {
    crate::block_layout::block_count::validate_block_count::<SurvivalLocationScaleError>(
        "SurvivalLocationScaleFamily joint Hessian second directional derivative",
        self.expected_blocks(),
        block_states.len(),
    )?;
    let p_total = *self
        .joint_block_offsets()
        .last()
        .ok_or_else(|| "missing joint block offsets".to_string())?;
    if d_beta_u_flat.len() != p_total || d_beta_v_flat.len() != p_total {
        return Err(SurvivalLocationScaleError::DimensionMismatch {
            reason: format!(
                "joint Hessian second directional derivative length mismatch: got {} / {}, expected {p_total}",
                d_beta_u_flat.len(),
                d_beta_v_flat.len()
            ),
        }
        .into());
    }
    let log_rescale = self.hessian_deriv_log_rescale(block_states);
    let q = self.collect_joint_quantities_rescaled(block_states, log_rescale)?;
    let dynamic = self.build_dynamic_geometry(block_states)?;
    // Both code paths below need the directions as plain contiguous slices.
    let u = d_beta_u_flat.as_slice().ok_or_else(|| {
        "joint Hessian second directional u must be contiguous".to_string()
    })?;
    let v = d_beta_v_flat.as_slice().ok_or_else(|| {
        "joint Hessian second directional v must be contiguous".to_string()
    })?;
    let all_rows = crate::row_kernel::RowSet::All;
    if self.x_link_wiggle.is_some() {
        let dense = super::row_kernel::survival_ls_wiggle_second_directional_derivative_dense(
            self,
            &q,
            &dynamic,
            log_rescale,
            &all_rows,
            u,
            v,
        )?;
        return Ok(Some(dense));
    }
    let kernel = self.survival_ls_row_kernel_rescaled(&q, &dynamic, log_rescale);
    crate::row_kernel::row_kernel_second_directional_derivative(&kernel, &all_rows, u, v)
        .map(Some)
}
/// Linear inequality constraints for a block.
///
/// The link-wiggle block gets nonnegativity constraints sized to its design
/// width. The time block returns a clone of `time_linear_constraints`. All
/// other blocks are unconstrained (`None`).
fn block_linear_constraints(
    &self,
    _: &[ParameterBlockState],
    block_idx: usize,
    spec: &ParameterBlockSpec,
) -> Result<Option<LinearInequalityConstraints>, String> {
    let constraints = match block_idx {
        Self::BLOCK_LINK_WIGGLE => monotone_wiggle_nonnegative_constraints(spec.design.ncols()),
        Self::BLOCK_TIME => self.time_linear_constraints.clone(),
        _ => None,
    };
    Ok(constraints)
}
fn max_feasible_step_size(
&self,
block_states: &[ParameterBlockState],
block_idx: usize,
delta: &Array1<f64>,
) -> Result<Option<f64>, String> {
if block_idx == Self::BLOCK_TIME {
return self.max_feasible_time_step(&block_states[Self::BLOCK_TIME].beta, delta);
}
if block_idx == Self::BLOCK_LINK_WIGGLE {
return self
.max_feasible_link_wiggle_step(&block_states[Self::BLOCK_LINK_WIGGLE].beta, delta);
}
Ok(None)
}
fn joint_trust_metric_block_floor(
&self,
block_states: &[ParameterBlockState],
_: &[ParameterBlockSpec],
) -> Result<Option<Array1<f64>>, String> {
let offsets = self.joint_block_offsets();
if offsets.len() < 2 {
return Ok(None);
}
let p_total = *offsets
.last()
.ok_or_else(|| "missing joint block offsets".to_string())?;
let log_scale = self.hessian_deriv_log_rescale(block_states);
let q = self.collect_joint_quantities_rescaled(block_states, log_scale)?;
let Some(h_joint) = self.assemble_joint_hessian_from_quantities(&q, block_states)? else {
return Ok(None);
};
if h_joint.nrows() != p_total {
return Ok(None);
}
let mut floor = Array1::<f64>::zeros(p_total);
let mut any = false;
for &block in &[Self::BLOCK_THRESHOLD, Self::BLOCK_LOG_SIGMA] {
if block + 1 >= offsets.len() {
continue;
}
let (start, end) = (offsets[block], offsets[block + 1]);
if end <= start {
continue;
}
let max_diag = (start..end)
.map(|j| h_joint[[j, j]].abs())
.filter(|v| v.is_finite())
.fold(0.0_f64, f64::max);
if !(max_diag.is_finite() && max_diag > 0.0) {
continue;
}
let floor_value = SCALE_COUPLED_TRUST_METRIC_FLOOR_REL * max_diag;
if !(floor_value.is_finite() && floor_value > 0.0) {
continue;
}
for j in start..end {
floor[j] = floor_value;
}
any = true;
}
if any { Ok(Some(floor)) } else { Ok(None) }
}
/// Validates (without modifying) a block's coefficients after an update.
///
/// For the time block with configured linear constraints, the constraints
/// are checked via `validate_linear_constraints`. When a link-wiggle design is
/// present, every wiggle coefficient must be finite and not below
/// `-CONSTRAINT_NONNEGATIVITY_REL_TOL * max(|beta_j|, 1)`. `beta` is returned
/// unchanged on success.
///
/// # Panics
/// Panics if `block_spec.name` is empty.
fn post_update_block_beta(
    &self,
    _: &[ParameterBlockState],
    block_idx: usize,
    block_spec: &ParameterBlockSpec,
    beta: Array1<f64>,
) -> Result<Array1<f64>, String> {
    assert!(!block_spec.name.is_empty());
    if block_idx == Self::BLOCK_TIME {
        if let Some(constraints) = self.time_linear_constraints.as_ref() {
            validate_linear_constraints("time post-update", &beta, constraints)?;
        }
    } else if block_idx == Self::BLOCK_LINK_WIGGLE && self.x_link_wiggle.is_some() {
        for (j, &value) in beta.iter().enumerate() {
            // Relative tolerance, but never tighter than the absolute base tolerance.
            let tol = CONSTRAINT_NONNEGATIVITY_REL_TOL * value.abs().max(1.0);
            let violates = !value.is_finite() || value < -tol;
            if violates {
                return Err(SurvivalLocationScaleError::ConstraintViolation {
                    reason: format!(
                        "survival location-scale link-wiggle post-update violates represented nonnegativity at coefficient {j}: value={:.3e}, tol={:.3e}",
                        value, tol
                    ),
                }
                .into());
            }
        }
    }
    Ok(beta)
}
/// Builds a joint-Hessian workspace over all rows after validating `specs`.
fn exact_newton_joint_hessian_workspace(
    &self,
    block_states: &[ParameterBlockState],
    specs: &[ParameterBlockSpec],
) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
    self.validate_joint_specs(specs, "SurvivalLocationScaleFamily joint Hessian workspace")?;
    let workspace = SurvivalLocationScaleExactNewtonJointHessianWorkspace::new(
        self.clone(),
        block_states.to_vec(),
    )?;
    Ok(Some(Arc::new(workspace)))
}
/// Builds a joint-Hessian workspace after validating `specs`.
///
/// The workspace is restricted to `options.outer_score_subsample` when one is
/// set.
fn exact_newton_joint_hessian_workspace_with_options(
    &self,
    block_states: &[ParameterBlockState],
    specs: &[ParameterBlockSpec],
    options: &BlockwiseFitOptions,
) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
    let context = "SurvivalLocationScaleFamily joint Hessian workspace with options";
    self.validate_joint_specs(specs, context)?;
    let mut workspace = SurvivalLocationScaleExactNewtonJointHessianWorkspace::new(
        self.clone(),
        block_states.to_vec(),
    )?;
    if let Some(subsample) = &options.outer_score_subsample {
        workspace.apply_outer_subsample(subsample.rows.as_ref());
    }
    Ok(Some(Arc::new(workspace)))
}
/// Always `true`: the `*_with_options` entry points in this impl
/// (log-likelihood, psi workspace, joint-Hessian workspace) honour
/// `options.outer_score_subsample`.
fn outer_derivative_subsample_capable(&self) -> bool {
true
}
}
impl SurvivalLocationScaleFamily {
pub(crate) fn exact_newton_joint_psi_terms_masked(
&self,
block_states: &[ParameterBlockState],
specs: &[ParameterBlockSpec],
derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
psi_index: usize,
row_mask: Option<&Array1<f64>>,
) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
if specs.len() != self.expected_blocks()
|| derivative_blocks.len() != self.expected_blocks()
{
return Err(SurvivalLocationScaleError::DimensionMismatch { reason: format!(
"SurvivalLocationScaleFamily joint psi terms expect {} specs and derivative blocks, got {} and {}",
self.expected_blocks(),
specs.len(),
derivative_blocks.len()
) }.into());
}
let Some(dir) =
self.exact_newton_joint_psi_direction(block_states, derivative_blocks, psi_index)?
else {
return Ok(None);
};
let z_t_exit_psi = &dir.z_t_exit_psi;
let z_t_entry_psi = &dir.z_t_entry_psi;
let z_ls_exit_psi = &dir.z_ls_exit_psi;
let z_ls_entry_psi = &dir.z_ls_entry_psi;
let q = self.collect_joint_quantities(block_states)?;
let dynamic = self.build_dynamic_geometry(block_states)?;
let offsets = self.joint_block_offsets();
let p_total = *offsets
.last()
.ok_or_else(|| "missing joint block offsets".to_string())?;
let x_threshold_exit_cow = self.x_threshold.to_dense_cow();
let x_threshold_exit = &*x_threshold_exit_cow;
let x_threshold_entry_cow = self
.x_threshold_entry
.as_ref()
.map(DesignMatrix::to_dense_cow);
let x_threshold_entry = x_threshold_entry_cow
.as_ref()
.map_or(x_threshold_exit, |c| &**c);
let x_log_sigma_exit_cow = self.x_log_sigma.to_dense_cow();
let x_log_sigma_exit = &*x_log_sigma_exit_cow;
let x_log_sigma_entry_cow = self
.x_log_sigma_entry
.as_ref()
.map(DesignMatrix::to_dense_cow);
let x_log_sigma_entry = x_log_sigma_entry_cow
.as_ref()
.map_or(x_log_sigma_exit, |c| &**c);
let xw_cow = self.x_link_wiggle.as_ref().map(DesignMatrix::to_dense_cow);
let xw = xw_cow.as_deref();
let x_t_exit_map = first_psi_linear_map(
dir.x_t_exit_action.as_ref(),
dir.x_t_exit_psi.as_ref(),
self.n,
x_threshold_exit.ncols(),
);
let x_t_entry_map = first_psi_linear_map(
dir.x_t_entry_action.as_ref(),
dir.x_t_entry_psi.as_ref(),
self.n,
x_threshold_entry.ncols(),
);
let x_ls_exit_map = first_psi_linear_map(
dir.x_ls_exit_action.as_ref(),
dir.x_ls_exit_psi.as_ref(),
self.n,
x_log_sigma_exit.ncols(),
);
let x_ls_entry_map = first_psi_linear_map(
dir.x_ls_entry_action.as_ref(),
dir.x_ls_entry_psi.as_ref(),
self.n,
x_log_sigma_entry.ncols(),
);
let dq_t_entry = q.dq_t_entry.as_ref().unwrap_or(&q.dq_t);
let dq_ls_entry = q.dq_ls_entry.as_ref().unwrap_or(&q.dq_ls);
let d2q_tls_entry = q.d2q_tls_entry.as_ref().unwrap_or(&q.d2q_tls);
let d2q_ls_entry = q.d2q_ls_entry.as_ref().unwrap_or(&q.d2q_ls);
let d3q_tls_ls_entry = q.d3q_tls_ls_entry.as_ref().unwrap_or(&q.d3q_tls_ls);
let d3q_ls_entry = q.d3q_ls_entry.as_ref().unwrap_or(&q.d3q_ls);
let q0_psi = &(dq_t_entry * z_t_entry_psi) + &(dq_ls_entry * z_ls_entry_psi);
let q1_psi = &(&q.dq_t * z_t_exit_psi) + &(&q.dq_ls * z_ls_exit_psi);
let dq_t_entry_psi = d2q_tls_entry * z_ls_entry_psi;
let dq_t_exit_psi = &q.d2q_tls * z_ls_exit_psi;
let dq_ls_entry_psi = d2q_tls_entry * z_t_entry_psi + d2q_ls_entry * z_ls_entry_psi;
let dq_ls_exit_psi = &q.d2q_tls * z_t_exit_psi + &q.d2q_ls * z_ls_exit_psi;
let d2q_tls_entry_psi = d3q_tls_ls_entry * z_ls_entry_psi;
let d2q_tls_exit_psi = &q.d3q_tls_ls * z_ls_exit_psi;
let d2q_ls_entry_psi = d3q_tls_ls_entry * z_t_entry_psi + d3q_ls_entry * z_ls_entry_psi;
let d2q_ls_exit_psi = &q.d3q_tls_ls * z_t_exit_psi + &q.d3q_ls * z_ls_exit_psi;
let objective_psi = if let Some(m) = row_mask {
(&(&q.d1_q0 * &q0_psi) * m).sum() + (&(&q.d1_q1 * &q1_psi) * m).sum()
} else {
q.d1_q0.dot(&q0_psi) + q.d1_q1.dot(&q1_psi)
};
let mut score_psi = Array1::<f64>::zeros(p_total);
let time_row_entry = -&q.d2_q0 * &q0_psi;
let time_row_exit = -&q.d2_q1 * &q1_psi;
let time_score = dynamic
.time_jac_entry
.t()
.dot(&*mask_row_vec(&time_row_entry, row_mask))
+ dynamic
.time_jac_exit
.t()
.dot(&*mask_row_vec(&time_row_exit, row_mask));
score_psi
.slice_mut(s![offsets[0]..offsets[1]])
.assign(&time_score);
let threshold_score_row_exit = &q.d1_q1 * &q.dq_t;
let threshold_score_row_entry = &q.d1_q0 * dq_t_entry;
let d_threshold_score_row_exit = &q.d2_q1 * &q1_psi * &q.dq_t + &q.d1_q1 * &dq_t_exit_psi;
let d_threshold_score_row_entry =
&q.d2_q0 * &q0_psi * dq_t_entry + &q.d1_q0 * &dq_t_entry_psi;
let threshold_score = x_t_exit_map
.transpose_mul(mask_row_vec(&threshold_score_row_exit, row_mask).view())
+ x_threshold_exit
.t()
.dot(&*mask_row_vec(&d_threshold_score_row_exit, row_mask))
+ x_t_entry_map
.transpose_mul(mask_row_vec(&threshold_score_row_entry, row_mask).view())
+ x_threshold_entry
.t()
.dot(&*mask_row_vec(&d_threshold_score_row_entry, row_mask));
score_psi
.slice_mut(s![offsets[1]..offsets[2]])
.assign(&threshold_score);
let log_sigma_score_row_exit = &q.d1_q1 * &q.dq_ls;
let log_sigma_score_row_entry = &q.d1_q0 * dq_ls_entry;
let d_log_sigma_score_row_exit = &q.d2_q1 * &q1_psi * &q.dq_ls + &q.d1_q1 * &dq_ls_exit_psi;
let d_log_sigma_score_row_entry =
&q.d2_q0 * &q0_psi * dq_ls_entry + &q.d1_q0 * &dq_ls_entry_psi;
let log_sigma_score = x_ls_exit_map
.transpose_mul(mask_row_vec(&log_sigma_score_row_exit, row_mask).view())
+ x_log_sigma_exit
.t()
.dot(&*mask_row_vec(&d_log_sigma_score_row_exit, row_mask))
+ x_ls_entry_map
.transpose_mul(mask_row_vec(&log_sigma_score_row_entry, row_mask).view())
+ x_log_sigma_entry
.t()
.dot(&*mask_row_vec(&d_log_sigma_score_row_entry, row_mask));
score_psi
.slice_mut(s![offsets[2]..offsets[3]])
.assign(&log_sigma_score);
if let (Some(xw_dense), Some(w_offset)) = (xw, offsets.get(3).copied()) {
let wiggle_row = &q.d2_q0 * &q0_psi + &q.d2_q1 * &q1_psi;
let wiggle_score = xw_dense.t().dot(&*mask_row_vec(&wiggle_row, row_mask));
score_psi
.slice_mut(s![w_offset..offsets[4]])
.assign(&wiggle_score);
}
let h_time_time = mxtwxd(&dynamic.time_jac_entry, &(-&q.d3_q0 * &q0_psi), row_mask)
+ mxtwxd(&dynamic.time_jac_exit, &(-&q.d3_q1 * &q1_psi), row_mask);
let h_tt_entry = -(&q.d2_q0 * &dq_t_entry.mapv(|v| safe_product(v, v)));
let h_tt_exit = -(&q.d2_q1 * &q.dq_t.mapv(|v| safe_product(v, v)));
let dh_tt_entry = -(&q.d3_q0 * &q0_psi * &dq_t_entry.mapv(|v| safe_product(v, v))
+ &(2.0 * &q.d2_q0 * dq_t_entry * &dq_t_entry_psi));
let dh_tt_exit = -(&q.d3_q1 * &q1_psi * &q.dq_t.mapv(|v| safe_product(v, v))
+ &(2.0 * &q.d2_q1 * &q.dq_t * &dq_t_exit_psi));
let h_ll_entry =
-(&q.d2_q0 * &dq_ls_entry.mapv(|v| safe_product(v, v)) + &(&q.d1_q0 * d2q_ls_entry));
let h_ll_exit =
-(&q.d2_q1 * &q.dq_ls.mapv(|v| safe_product(v, v)) + &(&q.d1_q1 * &q.d2q_ls));
let dh_ll_entry = -(&q.d3_q0 * &q0_psi * &dq_ls_entry.mapv(|v| safe_product(v, v))
+ &(2.0 * &q.d2_q0 * dq_ls_entry * &dq_ls_entry_psi)
+ &(&q.d2_q0 * &q0_psi * d2q_ls_entry)
+ &(&q.d1_q0 * &d2q_ls_entry_psi));
let dh_ll_exit = -(&q.d3_q1 * &q1_psi * &q.dq_ls.mapv(|v| safe_product(v, v))
+ &(2.0 * &q.d2_q1 * &q.dq_ls * &dq_ls_exit_psi)
+ &(&q.d2_q1 * &q1_psi * &q.d2q_ls)
+ &(&q.d1_q1 * &d2q_ls_exit_psi));
let h_tl_entry = -(&q.d2_q0 * &(dq_t_entry * dq_ls_entry) + &(&q.d1_q0 * d2q_tls_entry));
let h_tl_exit = -(&q.d2_q1 * &(&q.dq_t * &q.dq_ls) + &(&q.d1_q1 * &q.d2q_tls));
let dh_tl_entry = -(&q.d3_q0 * &q0_psi * &(dq_t_entry * dq_ls_entry)
+ &(&q.d2_q0 * &(&dq_t_entry_psi * dq_ls_entry + dq_t_entry * &dq_ls_entry_psi))
+ &(&q.d2_q0 * &q0_psi * d2q_tls_entry)
+ &(&q.d1_q0 * &d2q_tls_entry_psi));
let dh_tl_exit = -(&q.d3_q1 * &q1_psi * &(&q.dq_t * &q.dq_ls)
+ &(&q.d2_q1 * &(&dq_t_exit_psi * &q.dq_ls + &q.dq_t * &dq_ls_exit_psi))
+ &(&q.d2_q1 * &q1_psi * &q.d2q_tls)
+ &(&q.d1_q1 * &d2q_tls_exit_psi));
let h_h0_t = &q.d2_q0 * dq_t_entry;
let h_h1_t = &q.d2_q1 * &q.dq_t;
let dh_h0_t = &q.d3_q0 * &q0_psi * dq_t_entry + &q.d2_q0 * &dq_t_entry_psi;
let dh_h1_t = &q.d3_q1 * &q1_psi * &q.dq_t + &q.d2_q1 * &dq_t_exit_psi;
let h_h0_ls = &q.d2_q0 * dq_ls_entry;
let h_h1_ls = &q.d2_q1 * &q.dq_ls;
let dh_h0_ls = &q.d3_q0 * &q0_psi * dq_ls_entry + &q.d2_q0 * &dq_ls_entry_psi;
let dh_h1_ls = &q.d3_q1 * &q1_psi * &q.dq_ls + &q.d2_q1 * &dq_ls_exit_psi;
let h_tw_entry = -(&q.d2_q0 * dq_t_entry);
let h_tw_exit = -(&q.d2_q1 * &q.dq_t);
let dh_tw_entry = -(&q.d3_q0 * &q0_psi * dq_t_entry + &q.d2_q0 * &dq_t_entry_psi);
let dh_tw_exit = -(&q.d3_q1 * &q1_psi * &q.dq_t + &q.d2_q1 * &dq_t_exit_psi);
let h_lw_entry = -(&q.d2_q0 * dq_ls_entry);
let h_lw_exit = -(&q.d2_q1 * &q.dq_ls);
let dh_lw_entry = -(&q.d3_q0 * &q0_psi * dq_ls_entry + &q.d2_q0 * &dq_ls_entry_psi);
let dh_lw_exit = -(&q.d3_q1 * &q1_psi * &q.dq_ls + &q.d2_q1 * &dq_ls_exit_psi);
if dir.x_t_exit_action.is_some()
|| dir.x_t_entry_action.is_some()
|| dir.x_ls_exit_action.is_some()
|| dir.x_ls_entry_action.is_some()
{
let mw = |arr: Array1<f64>| -> Array1<f64> {
match row_mask {
Some(m) => &arr * m,
None => arr,
}
};
let mut channels = vec![
CustomFamilyJointDesignChannel::new(
offsets[0]..offsets[1],
shared_dense_arc(&self.x_time_entry),
None,
),
CustomFamilyJointDesignChannel::new(
offsets[0]..offsets[1],
shared_dense_arc(&self.x_time_exit),
None,
),
CustomFamilyJointDesignChannel::new(
offsets[1]..offsets[2],
shared_dense_arc(x_threshold_exit),
dir.x_t_exit_action.clone(),
),
CustomFamilyJointDesignChannel::new(
offsets[1]..offsets[2],
shared_dense_arc(x_threshold_entry),
dir.x_t_entry_action.clone(),
),
CustomFamilyJointDesignChannel::new(
offsets[2]..offsets[3],
shared_dense_arc(x_log_sigma_exit),
dir.x_ls_exit_action.clone(),
),
CustomFamilyJointDesignChannel::new(
offsets[2]..offsets[3],
shared_dense_arc(x_log_sigma_entry),
dir.x_ls_entry_action.clone(),
),
];
let mut pairs = vec![
CustomFamilyJointDesignPairContribution::new(
0,
0,
mw(Array1::zeros(self.x_time_entry.nrows())),
mw(-&q.d3_q0 * &q0_psi),
),
CustomFamilyJointDesignPairContribution::new(
1,
1,
mw(Array1::zeros(self.x_time_exit.nrows())),
mw(-&q.d3_q1 * &q1_psi),
),
CustomFamilyJointDesignPairContribution::new(
2,
2,
mw(h_tt_exit.clone()),
mw(dh_tt_exit.clone()),
),
CustomFamilyJointDesignPairContribution::new(
3,
3,
mw(h_tt_entry.clone()),
mw(dh_tt_entry.clone()),
),
CustomFamilyJointDesignPairContribution::new(
4,
4,
mw(h_ll_exit.clone()),
mw(dh_ll_exit.clone()),
),
CustomFamilyJointDesignPairContribution::new(
5,
5,
mw(h_ll_entry.clone()),
mw(dh_ll_entry.clone()),
),
CustomFamilyJointDesignPairContribution::new(
2,
4,
mw(h_tl_exit.clone()),
mw(dh_tl_exit.clone()),
),
CustomFamilyJointDesignPairContribution::new(
4,
2,
mw(h_tl_exit.clone()),
mw(dh_tl_exit.clone()),
),
CustomFamilyJointDesignPairContribution::new(
3,
5,
mw(h_tl_entry.clone()),
mw(dh_tl_entry.clone()),
),
CustomFamilyJointDesignPairContribution::new(
5,
3,
mw(h_tl_entry.clone()),
mw(dh_tl_entry.clone()),
),
CustomFamilyJointDesignPairContribution::new(
0,
3,
mw(h_h0_t.clone()),
mw(dh_h0_t.clone()),
),
CustomFamilyJointDesignPairContribution::new(
3,
0,
mw(h_h0_t.clone()),
mw(dh_h0_t.clone()),
),
CustomFamilyJointDesignPairContribution::new(
1,
2,
mw(h_h1_t.clone()),
mw(dh_h1_t.clone()),
),
CustomFamilyJointDesignPairContribution::new(
2,
1,
mw(h_h1_t.clone()),
mw(dh_h1_t.clone()),
),
CustomFamilyJointDesignPairContribution::new(
0,
5,
mw(h_h0_ls.clone()),
mw(dh_h0_ls.clone()),
),
CustomFamilyJointDesignPairContribution::new(
5,
0,
mw(h_h0_ls.clone()),
mw(dh_h0_ls.clone()),
),
CustomFamilyJointDesignPairContribution::new(
1,
4,
mw(h_h1_ls.clone()),
mw(dh_h1_ls.clone()),
),
CustomFamilyJointDesignPairContribution::new(
4,
1,
mw(h_h1_ls.clone()),
mw(dh_h1_ls.clone()),
),
];
if let (Some(xw_dense), Some(w_offset)) = (xw, offsets.get(3).copied()) {
channels.push(CustomFamilyJointDesignChannel::new(
w_offset..offsets[4],
shared_dense_arc(xw_dense),
None,
));
let w_idx = channels.len() - 1;
let zero_w = Array1::zeros(xw_dense.nrows());
pairs.push(CustomFamilyJointDesignPairContribution::new(
w_idx,
w_idx,
mw(zero_w.clone()),
mw(-&q.d3_q0 * &q0_psi - &q.d3_q1 * &q1_psi),
));
pairs.push(CustomFamilyJointDesignPairContribution::new(
2,
w_idx,
mw(h_tw_exit.clone()),
mw(dh_tw_exit.clone()),
));
pairs.push(CustomFamilyJointDesignPairContribution::new(
w_idx,
2,
mw(h_tw_exit.clone()),
mw(dh_tw_exit.clone()),
));
pairs.push(CustomFamilyJointDesignPairContribution::new(
3,
w_idx,
mw(h_tw_entry.clone()),
mw(dh_tw_entry.clone()),
));
pairs.push(CustomFamilyJointDesignPairContribution::new(
w_idx,
3,
mw(h_tw_entry.clone()),
mw(dh_tw_entry.clone()),
));
pairs.push(CustomFamilyJointDesignPairContribution::new(
4,
w_idx,
mw(h_lw_exit.clone()),
mw(dh_lw_exit.clone()),
));
pairs.push(CustomFamilyJointDesignPairContribution::new(
w_idx,
4,
mw(h_lw_exit.clone()),
mw(dh_lw_exit.clone()),
));
pairs.push(CustomFamilyJointDesignPairContribution::new(
5,
w_idx,
mw(h_lw_entry.clone()),
mw(dh_lw_entry.clone()),
));
pairs.push(CustomFamilyJointDesignPairContribution::new(
w_idx,
5,
mw(h_lw_entry.clone()),
mw(dh_lw_entry.clone()),
));
pairs.push(CustomFamilyJointDesignPairContribution::new(
0,
w_idx,
mw(zero_w.clone()),
mw(&q.d3_q0 * &q0_psi),
));
pairs.push(CustomFamilyJointDesignPairContribution::new(
w_idx,
0,
mw(zero_w.clone()),
mw(&q.d3_q0 * &q0_psi),
));
pairs.push(CustomFamilyJointDesignPairContribution::new(
1,
w_idx,
mw(zero_w.clone()),
mw(&q.d3_q1 * &q1_psi),
));
pairs.push(CustomFamilyJointDesignPairContribution::new(
w_idx,
1,
mw(zero_w),
mw(&q.d3_q1 * &q1_psi),
));
}
return Ok(Some(ExactNewtonJointPsiTerms {
objective_psi,
score_psi,
hessian_psi: Array2::zeros((0, 0)),
hessian_psi_operator: Some(std::sync::Arc::new(CustomFamilyJointPsiOperator::new(
p_total, channels, pairs,
))),
}));
}
let mut hessian_psi = Array2::<f64>::zeros((p_total, p_total));
assign_symmetric_block(&mut hessian_psi, offsets[0], offsets[0], &h_time_time);
let h_threshold_threshold =
mxtwx_psi(
x_t_exit_map,
h_tt_exit.view(),
CustomFamilyPsiLinearMapRef::Dense(x_threshold_exit),
row_mask,
)? + mxtwx_psi(
CustomFamilyPsiLinearMapRef::Dense(x_threshold_exit),
h_tt_exit.view(),
x_t_exit_map,
row_mask,
)? + mxtwx(x_threshold_exit, &dh_tt_exit, x_threshold_exit, row_mask)?
+ mxtwx_psi(
x_t_entry_map,
h_tt_entry.view(),
CustomFamilyPsiLinearMapRef::Dense(x_threshold_entry),
row_mask,
)?
+ mxtwx_psi(
CustomFamilyPsiLinearMapRef::Dense(x_threshold_entry),
h_tt_entry.view(),
x_t_entry_map,
row_mask,
)?
+ mxtwx(x_threshold_entry, &dh_tt_entry, x_threshold_entry, row_mask)?;
assign_symmetric_block(
&mut hessian_psi,
offsets[1],
offsets[1],
&h_threshold_threshold,
);
let h_log_sigma_log_sigma =
mxtwx_psi(
x_ls_exit_map,
h_ll_exit.view(),
CustomFamilyPsiLinearMapRef::Dense(x_log_sigma_exit),
row_mask,
)? + mxtwx_psi(
CustomFamilyPsiLinearMapRef::Dense(x_log_sigma_exit),
h_ll_exit.view(),
x_ls_exit_map,
row_mask,
)? + mxtwx(x_log_sigma_exit, &dh_ll_exit, x_log_sigma_exit, row_mask)?
+ mxtwx_psi(
x_ls_entry_map,
h_ll_entry.view(),
CustomFamilyPsiLinearMapRef::Dense(x_log_sigma_entry),
row_mask,
)?
+ mxtwx_psi(
CustomFamilyPsiLinearMapRef::Dense(x_log_sigma_entry),
h_ll_entry.view(),
x_ls_entry_map,
row_mask,
)?
+ mxtwx(x_log_sigma_entry, &dh_ll_entry, x_log_sigma_entry, row_mask)?;
assign_symmetric_block(
&mut hessian_psi,
offsets[2],
offsets[2],
&h_log_sigma_log_sigma,
);
let h_threshold_log_sigma =
mxtwx_psi(
x_t_exit_map,
h_tl_exit.view(),
CustomFamilyPsiLinearMapRef::Dense(x_log_sigma_exit),
row_mask,
)? + mxtwx_psi(
CustomFamilyPsiLinearMapRef::Dense(x_threshold_exit),
h_tl_exit.view(),
x_ls_exit_map,
row_mask,
)? + mxtwx(x_threshold_exit, &dh_tl_exit, x_log_sigma_exit, row_mask)?
+ mxtwx_psi(
x_t_entry_map,
h_tl_entry.view(),
CustomFamilyPsiLinearMapRef::Dense(x_log_sigma_entry),
row_mask,
)?
+ mxtwx_psi(
CustomFamilyPsiLinearMapRef::Dense(x_threshold_entry),
h_tl_entry.view(),
x_ls_entry_map,
row_mask,
)?
+ mxtwx(x_threshold_entry, &dh_tl_entry, x_log_sigma_entry, row_mask)?;
assign_symmetric_block(
&mut hessian_psi,
offsets[1],
offsets[2],
&h_threshold_log_sigma,
);
let h_time_threshold = mxtwx(&self.x_time_entry, &dh_h0_t, x_threshold_entry, row_mask)?
+ mxtwx_psi(
CustomFamilyPsiLinearMapRef::Dense(&self.x_time_entry),
h_h0_t.view(),
x_t_entry_map,
row_mask,
)?
+ mxtwx(&self.x_time_exit, &dh_h1_t, x_threshold_exit, row_mask)?
+ mxtwx_psi(
CustomFamilyPsiLinearMapRef::Dense(&self.x_time_exit),
h_h1_t.view(),
x_t_exit_map,
row_mask,
)?;
assign_symmetric_block(&mut hessian_psi, offsets[0], offsets[1], &h_time_threshold);
let h_time_log_sigma = mxtwx(&self.x_time_entry, &dh_h0_ls, x_log_sigma_entry, row_mask)?
+ mxtwx_psi(
CustomFamilyPsiLinearMapRef::Dense(&self.x_time_entry),
h_h0_ls.view(),
x_ls_entry_map,
row_mask,
)?
+ mxtwx(&self.x_time_exit, &dh_h1_ls, x_log_sigma_exit, row_mask)?
+ mxtwx_psi(
CustomFamilyPsiLinearMapRef::Dense(&self.x_time_exit),
h_h1_ls.view(),
x_ls_exit_map,
row_mask,
)?;
assign_symmetric_block(&mut hessian_psi, offsets[0], offsets[2], &h_time_log_sigma);
if let (Some(xw_dense), Some(w_offset)) = (xw, offsets.get(3).copied()) {
let h_ww = -(&q.d3_q0 * &q0_psi + &q.d3_q1 * &q1_psi);
let h_wiggle_wiggle = mxtwx(xw_dense, &h_ww, xw_dense, row_mask)?;
assign_symmetric_block(&mut hessian_psi, w_offset, w_offset, &h_wiggle_wiggle);
let h_threshold_wiggle = mxtwx_psi(
x_t_exit_map,
h_tw_exit.view(),
CustomFamilyPsiLinearMapRef::Dense(xw_dense),
row_mask,
)? + mxtwx(x_threshold_exit, &dh_tw_exit, xw_dense, row_mask)?
+ mxtwx_psi(
x_t_entry_map,
h_tw_entry.view(),
CustomFamilyPsiLinearMapRef::Dense(xw_dense),
row_mask,
)?
+ mxtwx(x_threshold_entry, &dh_tw_entry, xw_dense, row_mask)?;
assign_symmetric_block(&mut hessian_psi, offsets[1], w_offset, &h_threshold_wiggle);
let h_log_sigma_wiggle = mxtwx_psi(
x_ls_exit_map,
h_lw_exit.view(),
CustomFamilyPsiLinearMapRef::Dense(xw_dense),
row_mask,
)? + mxtwx(x_log_sigma_exit, &dh_lw_exit, xw_dense, row_mask)?
+ mxtwx_psi(
x_ls_entry_map,
h_lw_entry.view(),
CustomFamilyPsiLinearMapRef::Dense(xw_dense),
row_mask,
)?
+ mxtwx(x_log_sigma_entry, &dh_lw_entry, xw_dense, row_mask)?;
assign_symmetric_block(&mut hessian_psi, offsets[2], w_offset, &h_log_sigma_wiggle);
let h_time_wiggle =
mxtwx(
&self.x_time_entry,
&(&q.d3_q0 * &q0_psi),
xw_dense,
row_mask,
)? + mxtwx(&self.x_time_exit, &(&q.d3_q1 * &q1_psi), xw_dense, row_mask)?;
assign_symmetric_block(&mut hessian_psi, offsets[0], w_offset, &h_time_wiggle);
}
Ok(Some(ExactNewtonJointPsiTerms {
objective_psi,
score_psi,
hessian_psi,
hessian_psi_operator: None,
}))
}
}
/// Owned state needed to evaluate exact-Newton joint ψ-derivative terms for the
/// survival location-scale family.
///
/// An optional per-row weight mask can be installed with
/// [`SurvivalExactNewtonJointPsiWorkspace::apply_outer_subsample`].
pub(crate) struct SurvivalExactNewtonJointPsiWorkspace {
    /// Family whose joint ψ terms are evaluated.
    pub(crate) family: SurvivalLocationScaleFamily,
    /// Per-block coefficient / linear-predictor states.
    pub(crate) block_states: Vec<ParameterBlockState>,
    /// Per-block parameter specifications.
    pub(crate) specs: Vec<ParameterBlockSpec>,
    /// ψ-derivative descriptors grouped by parameter block; the total number of
    /// entries across all groups is the number of ψ coordinates.
    pub(crate) derivative_blocks: Vec<Vec<CustomFamilyBlockPsiDerivative>>,
    /// Per-row weights of length `family.n`; `None` until an outer subsample is applied.
    pub(crate) row_mask: Option<Arc<Array1<f64>>>,
}
impl SurvivalExactNewtonJointPsiWorkspace {
    /// Creates a workspace with no row mask installed.
    ///
    /// # Errors
    /// Never fails at present; the `Result` return type is kept for interface stability.
    pub(crate) fn new(
        family: SurvivalLocationScaleFamily,
        block_states: Vec<ParameterBlockState>,
        specs: Vec<ParameterBlockSpec>,
        derivative_blocks: Vec<Vec<CustomFamilyBlockPsiDerivative>>,
    ) -> Result<Self, String> {
        let workspace = Self {
            row_mask: None,
            family,
            block_states,
            specs,
            derivative_blocks,
        };
        Ok(workspace)
    }
    /// Replaces the row mask with a dense weight vector built from `rows`.
    ///
    /// Every row starts at weight `0.0`. Each entry whose index is below `family.n`
    /// writes its weight, and a later entry for the same index overwrites an earlier one.
    /// Entries with out-of-range indices are silently ignored.
    pub(crate) fn apply_outer_subsample(
        &mut self,
        rows: &[crate::outer_subsample::WeightedOuterRow],
    ) {
        let row_count = self.family.n;
        let mut weights = Array1::<f64>::zeros(row_count);
        for row in rows.iter().filter(|row| row.index < row_count) {
            weights[row.index] = row.weight;
        }
        self.row_mask = Some(Arc::new(weights));
    }
}
impl ExactNewtonJointPsiWorkspace for SurvivalExactNewtonJointPsiWorkspace {
    /// First-order ψ terms (objective, score and Hessian contributions) for the ψ
    /// coordinate `psi_index`.
    ///
    /// The workspace's row mask (if any) is forwarded to the family.
    fn first_order_terms(
        &self,
        psi_index: usize,
    ) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
        let mask = self.row_mask.as_deref();
        self.family.exact_newton_joint_psi_terms_masked(
            &self.block_states,
            &self.specs,
            &self.derivative_blocks,
            psi_index,
            mask,
        )
    }
    /// No analytic second-order ψ terms are supplied by this workspace.
    ///
    /// Every `(psi_i, psi_j)` pair yields `Ok(None)`, whether or not it is in range.
    fn second_order_terms(
        &self,
        psi_i: usize,
        psi_j: usize,
    ) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String> {
        let _ = (psi_i, psi_j);
        Ok(None)
    }
    /// Directional derivative (along `d_beta_flat`) of the ψ-Hessian for coordinate
    /// `psi_index`, returned as a dense drift result.
    ///
    /// # Errors
    /// - The family reports no joint block offsets.
    /// - `d_beta_flat` does not have the total joint coefficient length.
    ///
    /// Returns `Ok(None)` when `psi_index` is not a valid ψ coordinate.
    fn hessian_directional_derivative(
        &self,
        psi_index: usize,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<gam_problem::DriftDerivResult>, String> {
        let p_total = self
            .family
            .joint_block_offsets()
            .last()
            .copied()
            .ok_or_else(|| "missing joint block offsets".to_string())?;
        if d_beta_flat.len() != p_total {
            return Err(SurvivalLocationScaleError::DimensionMismatch {
                reason: format!(
                    "joint psi workspace Hessian directional derivative d_beta length mismatch: got {}, expected {p_total}",
                    d_beta_flat.len()
                ),
            }
            .into());
        }
        let psi_dim: usize = self.derivative_blocks.iter().map(Vec::len).sum();
        if psi_index >= psi_dim {
            return Ok(None);
        }
        // NOTE(review): unlike `first_order_terms`, this call does not receive
        // `self.row_mask`, so an installed outer subsample is not applied here.
        // Confirm whether a masked variant should be used.
        let dense = self
            .family
            .exact_newton_joint_psihessian_directional_derivative(
                &self.block_states,
                &self.specs,
                &self.derivative_blocks,
                psi_index,
                d_beta_flat,
            )?;
        Ok(dense.map(gam_problem::DriftDerivResult::Dense))
    }
}
/// Precomputed state for exact-Newton joint Hessian directional derivatives of the
/// survival location-scale family at a fixed set of block states.
pub(crate) struct SurvivalLocationScaleExactNewtonJointHessianWorkspace {
    /// Family the derivatives are evaluated for.
    pub(crate) family: SurvivalLocationScaleFamily,
    /// Joint quantities collected with the `deriv_log_scale` rescale applied.
    pub(crate) q: SurvivalJointQuantities,
    /// Dynamic geometry built from the block states at construction time.
    pub(crate) dynamic: SurvivalDynamicGeometry,
    /// Log rescale returned by `hessian_deriv_log_rescale` for the construction states.
    pub(crate) deriv_log_scale: f64,
    /// Per-row weights of length `family.n`; `None` until an outer subsample is applied.
    pub(crate) row_mask: Option<Arc<Array1<f64>>>,
}
impl SurvivalLocationScaleExactNewtonJointHessianWorkspace {
    /// Precomputes the rescaled joint quantities and dynamic geometry for `block_states`.
    ///
    /// # Errors
    /// Propagates failures from collecting the joint quantities (checked first) or from
    /// building the dynamic geometry.
    pub(crate) fn new(
        family: SurvivalLocationScaleFamily,
        block_states: Vec<ParameterBlockState>,
    ) -> Result<Self, String> {
        let deriv_log_scale = family.hessian_deriv_log_rescale(&block_states);
        let q = family.collect_joint_quantities_rescaled(&block_states, deriv_log_scale)?;
        let dynamic = family.build_dynamic_geometry(&block_states)?;
        Ok(Self {
            family,
            q,
            dynamic,
            deriv_log_scale,
            row_mask: None,
        })
    }
    /// Replaces the row mask with a dense weight vector built from `rows`.
    ///
    /// Every row starts at weight `0.0`. Each entry whose index is below `family.n`
    /// writes its weight, and a later entry for the same index overwrites an earlier one.
    /// Entries with out-of-range indices are silently ignored.
    pub(crate) fn apply_outer_subsample(
        &mut self,
        rows: &[crate::outer_subsample::WeightedOuterRow],
    ) {
        let row_count = self.family.n;
        let mut weights = Array1::<f64>::zeros(row_count);
        for row in rows.iter().filter(|row| row.index < row_count) {
            weights[row.index] = row.weight;
        }
        self.row_mask = Some(Arc::new(weights));
    }
}
impl ExactNewtonJointHessianWorkspace for SurvivalLocationScaleExactNewtonJointHessianWorkspace {
    /// Dense first directional derivative of the joint Hessian along `d_beta_flat`.
    ///
    /// Uses the cached rescaled quantities and the row mask, if one is installed.
    fn directional_derivative(
        &self,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let mask = self.row_mask.as_deref();
        self.family
            .exact_newton_joint_hessian_directional_derivative_rescaled_from_parts_masked(
                d_beta_flat,
                &self.q,
                &self.dynamic,
                self.deriv_log_scale,
                mask,
            )
    }
    /// Same as [`Self::directional_derivative`], wrapped in a dense hyper-operator.
    fn directional_derivative_operator(
        &self,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Arc<dyn HyperOperator>>, String> {
        let dense = self.directional_derivative(d_beta_flat)?;
        Ok(dense.map(|matrix| {
            Arc::new(DenseMatrixHyperOperator { matrix }) as Arc<dyn HyperOperator>
        }))
    }
    /// Dense second directional derivative of the joint Hessian along `d_beta_u_flat`
    /// and `d_beta_v_flat`.
    ///
    /// Models with a link wiggle use the dedicated wiggle kernel. All other models go
    /// through the generic row kernel built from the cached rescaled quantities.
    ///
    /// # Errors
    /// - The family reports no joint block offsets.
    /// - Either direction does not have the total joint coefficient length.
    /// - Either direction is not contiguous in memory.
    /// - The underlying kernel fails.
    fn second_directional_derivative(
        &self,
        d_beta_u_flat: &Array1<f64>,
        d_beta_v_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let p_total = self
            .family
            .joint_block_offsets()
            .last()
            .copied()
            .ok_or_else(|| "missing joint block offsets".to_string())?;
        let lengths_match = d_beta_u_flat.len() == p_total && d_beta_v_flat.len() == p_total;
        if !lengths_match {
            return Err(SurvivalLocationScaleError::DimensionMismatch {
                reason: format!(
                    "joint Hessian workspace second directional derivative length mismatch: got {} / {}, expected {p_total}",
                    d_beta_u_flat.len(),
                    d_beta_v_flat.len()
                ),
            }
            .into());
        }
        let rows = row_set_from_survival_mask(self.row_mask.as_deref(), self.family.n);
        // Both kernels take plain slices, so both directions must be contiguous.
        let u = d_beta_u_flat.as_slice().ok_or_else(|| {
            "joint Hessian workspace second directional u must be contiguous".to_string()
        })?;
        let v = d_beta_v_flat.as_slice().ok_or_else(|| {
            "joint Hessian workspace second directional v must be contiguous".to_string()
        })?;
        if self.family.x_link_wiggle.is_some() {
            let dense =
                super::row_kernel::survival_ls_wiggle_second_directional_derivative_dense(
                    &self.family,
                    &self.q,
                    &self.dynamic,
                    self.deriv_log_scale,
                    &rows,
                    u,
                    v,
                )?;
            return Ok(Some(dense));
        }
        let kernel = self.family.survival_ls_row_kernel_rescaled(
            &self.q,
            &self.dynamic,
            self.deriv_log_scale,
        );
        crate::row_kernel::row_kernel_second_directional_derivative(&kernel, &rows, u, v)
            .map(Some)
    }
    /// Same as [`Self::second_directional_derivative`], wrapped in a dense hyper-operator.
    fn second_directional_derivative_operator(
        &self,
        d_beta_u_flat: &Array1<f64>,
        d_beta_v_flat: &Array1<f64>,
    ) -> Result<Option<Arc<dyn HyperOperator>>, String> {
        let dense = self.second_directional_derivative(d_beta_u_flat, d_beta_v_flat)?;
        Ok(dense.map(|matrix| {
            Arc::new(DenseMatrixHyperOperator { matrix }) as Arc<dyn HyperOperator>
        }))
    }
}