use super::*;
/// [`CustomFamily`] implementation for [`SurvivalMarginalSlopeFamily`].
///
/// Most hooks dispatch on three evaluation regimes, tested in this order:
///
/// * per-z log-slope (`per_z_logslope_active`): only the dense evaluators are used. The
///   Hessian/psi workspaces, the HVP availability checks and the Hessian
///   directional-derivative hooks report `None`/`false`, so the caller takes its fallback path.
/// * flexible / time-wiggle (`effective_flex_active`, `flex_timewiggle_active`): dedicated
///   flex and time-wiggle evaluators.
/// * otherwise: the generic row kernel (`SurvivalMarginalSlopeRowKernel`).
impl CustomFamily for SurvivalMarginalSlopeFamily {
    /// The joint Jeffreys term is always requested for this family.
    fn joint_jeffreys_term_required(&self) -> bool {
        true
    }

    /// Returns `true`: this family opts into the driver's `levenberg_on_ill_conditioning`
    /// behaviour.
    fn levenberg_on_ill_conditioning(&self) -> bool {
        true
    }

    /// Fingerprint of the fitted data, used to key persisted warm starts.
    ///
    /// Returns `None` (no persistence) when the block specs do not match `self.n` rows, or
    /// when `options.inner_tol` is not a finite positive number.
    ///
    /// The hash covers:
    /// * `n`;
    /// * the event and weight vectors;
    /// * the `z` matrix (shape and values);
    /// * the optional Gaussian frailty SD;
    /// * the derivative guard;
    /// * the entry, exit and derivative-exit offsets.
    ///
    /// Each series is length-prefixed to keep adjacent series unambiguous.
    ///
    /// NOTE(review): `specs`, `inner_tol` and the structural configuration (score warp,
    /// link deviation, flex) are not mixed in here. Confirm that the caller keys on those
    /// separately.
    fn persistent_warm_start_fingerprint(
        &self,
        specs: &[ParameterBlockSpec],
        options: &BlockwiseFitOptions,
    ) -> Option<String> {
        if !parameter_block_specs_match_rows(specs, self.n)
            || !options.inner_tol.is_finite()
            || options.inner_tol <= 0.0
        {
            return None;
        }
        let mut hasher = gam_runtime::warm_start::Fingerprinter::new();
        hasher.write_str("survival-marginal-slope-family");
        hasher.write_usize(self.n);
        hasher.write_usize(self.event.len());
        for &value in self.event.iter() {
            hasher.write_f64(value);
        }
        hasher.write_usize(self.weights.len());
        for &value in self.weights.iter() {
            hasher.write_f64(value);
        }
        hasher.write_usize(self.z.nrows());
        hasher.write_usize(self.z.ncols());
        for &value in self.z.iter() {
            hasher.write_f64(value);
        }
        // Presence flag first, so `None` and `Some(_)` can never hash identically.
        match self.gaussian_frailty_sd {
            Some(value) => {
                hasher.write_bool(true);
                hasher.write_f64(value);
            }
            None => hasher.write_bool(false),
        }
        hasher.write_f64(self.derivative_guard);
        hasher.write_usize(self.offset_entry.len());
        for &value in self.offset_entry.iter() {
            hasher.write_f64(value);
        }
        hasher.write_usize(self.offset_exit.len());
        for &value in self.offset_exit.iter() {
            hasher.write_f64(value);
        }
        hasher.write_usize(self.derivative_offset_exit.len());
        for &value in self.derivative_offset_exit.iter() {
            hasher.write_f64(value);
        }
        Some(hasher.finish_hex())
    }

    /// The joint Hessian depends on beta, so it must be re-evaluated after every step.
    fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
        true
    }

    /// Hessian cost estimate from the shared joint-coupled, operator-aware cost model.
    fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
        crate::coefficient_cost::joint_coupled_operator_aware_hessian_cost(
            self.n as u64,
            specs,
        )
    }

    /// Cost model for the outer (hyper-parameter) derivatives.
    ///
    /// When flex or time-wiggle is active, the generic `default_outer_derivative_policy_costs`
    /// estimate is used.
    ///
    /// Otherwise the Hessian work is estimated in closed form as
    /// `n * k * (p_total + N_PRIMARY)`, where:
    /// * `k` is the total penalty count plus `psi_dim`, and at least 1;
    /// * `p_total` is the total number of design columns.
    ///
    /// The gradient work is taken as half of the Hessian work. All arithmetic saturates in
    /// `u128`. Subsampling is reported as supported in both cases.
    fn outer_derivative_policy(
        &self,
        specs: &[ParameterBlockSpec],
        psi_dim: usize,
        options: &BlockwiseFitOptions,
    ) -> crate::custom_family::OuterDerivativePolicy {
        use crate::custom_family::OuterDerivativePolicy;
        let capability = self.exact_outer_derivative_order(specs, options);
        if self.flex_active() || self.flex_timewiggle_active() {
            let (predicted_gradient_work, predicted_hessian_work) =
                crate::custom_family::default_outer_derivative_policy_costs(
                    specs,
                    psi_dim,
                    self.coefficient_gradient_cost(specs),
                    self.coefficient_hessian_cost(specs),
                );
            return OuterDerivativePolicy {
                capability,
                predicted_gradient_work,
                predicted_hessian_work,
                subsample_capable: true,
            };
        }
        let rho_dim = specs
            .iter()
            .map(|spec| spec.penalties.len() as u128)
            .sum::<u128>();
        let k = rho_dim.saturating_add(psi_dim as u128).max(1);
        let p_total = specs
            .iter()
            .map(|spec| spec.design.ncols() as u128)
            .sum::<u128>();
        let predicted_hessian_work = (self.n as u128)
            .saturating_mul(k)
            .saturating_mul(p_total.saturating_add(N_PRIMARY as u128));
        OuterDerivativePolicy {
            capability,
            predicted_gradient_work: predicted_hessian_work / 2,
            predicted_hessian_work,
            subsample_capable: true,
        }
    }

    /// Request the psi workspace even when only first-order psi terms are needed.
    fn exact_newton_joint_psi_workspace_for_first_order_terms(&self) -> bool {
        true
    }

    /// Full blockwise evaluation, delegated to `evaluate_blockwise_exact_newton`.
    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
        self.evaluate_blockwise_exact_newton(block_states)
    }

    /// Log-likelihood at `block_states`.
    ///
    /// Builds default options with `auto_outer_subsample: false` (presumably so that the
    /// value is a full-data likelihood) and delegates to the inherent
    /// `log_likelihood_only_with_options`.
    ///
    /// The type-qualified path resolves to the inherent method, because inherent items take
    /// precedence over trait items. It therefore does not re-enter this trait impl. This
    /// assumes an inherent method of that name exists.
    fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
        let options = BlockwiseFitOptions {
            auto_outer_subsample: false,
            ..BlockwiseFitOptions::default()
        };
        SurvivalMarginalSlopeFamily::log_likelihood_only_with_options(self, block_states, &options)
    }

    /// Log-likelihood under `options`.
    ///
    /// If `install_auto_outer_subsample_options` returns adjusted options, those are used in
    /// place of the caller's. The call is then delegated to the inherent method of the same
    /// name (see `log_likelihood_only`).
    fn log_likelihood_only_with_options(
        &self,
        block_states: &[ParameterBlockState],
        options: &BlockwiseFitOptions,
    ) -> Result<f64, String> {
        let owned;
        let options: &BlockwiseFitOptions = match self.install_auto_outer_subsample_options(options)
        {
            Some(cloned) => {
                owned = cloned;
                &owned
            }
            None => options,
        };
        SurvivalMarginalSlopeFamily::log_likelihood_only_with_options(self, block_states, options)
    }

    /// This family provides its own joint Hessian hooks.
    fn has_explicit_joint_hessian(&self) -> bool {
        true
    }

    /// Dense joint Hessian.
    ///
    /// The Hessian is materialised only when the total coefficient count is below
    /// `DENSE_HESSIAN_COEFFICIENT_LIMIT`; larger problems get `Ok(None)`.
    ///
    /// The per-z and standard dense evaluators both return the Hessian as tuple element `.2`.
    fn exact_newton_joint_hessian(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<Option<Array2<f64>>, String> {
        // A p x p dense matrix is only built for fewer than this many coefficients.
        const DENSE_HESSIAN_COEFFICIENT_LIMIT: usize = 512;
        let total = block_states
            .iter()
            .map(|state| state.beta.len())
            .sum::<usize>();
        if total >= DENSE_HESSIAN_COEFFICIENT_LIMIT {
            return Ok(None);
        }
        let hessian = if self.per_z_logslope_active() {
            self.evaluate_exact_newton_joint_dense_per_z(block_states)?.2
        } else {
            self.evaluate_exact_newton_joint_dense(block_states)?.2
        };
        Ok(Some(hessian))
    }

    /// Joint log-likelihood and gradient.
    ///
    /// The source depends on the regime:
    /// * per-z: the dense per-z evaluator;
    /// * flex or time-wiggle: the dynamic-q gradient;
    /// * otherwise: a full-row row-kernel cache, with the gradient taken as the negation of
    ///   `row_kernel_gradient`.
    fn exact_newton_joint_gradient_evaluation(
        &self,
        block_states: &[ParameterBlockState],
        _: &[ParameterBlockSpec],
    ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
        if self.per_z_logslope_active() {
            let (log_likelihood, gradient, _) =
                self.evaluate_exact_newton_joint_dense_per_z(block_states)?;
            return Ok(Some(ExactNewtonJointGradientEvaluation {
                log_likelihood,
                gradient,
            }));
        }
        if self.effective_flex_active(block_states)? || self.flex_timewiggle_active() {
            let (log_likelihood, gradient) =
                self.evaluate_exact_newton_joint_gradient_dynamic_q(block_states)?;
            return Ok(Some(ExactNewtonJointGradientEvaluation {
                log_likelihood,
                gradient,
            }));
        }
        let kern = SurvivalMarginalSlopeRowKernel::new(self.clone(), block_states.to_vec());
        let rows = crate::row_kernel::RowSet::All;
        let cache = build_row_kernel_cache(&kern, &rows)?;
        Ok(Some(ExactNewtonJointGradientEvaluation {
            log_likelihood: row_kernel_log_likelihood(&cache, &rows),
            gradient: -row_kernel_gradient(&kern, &cache, &rows),
        }))
    }

    /// This family always requires the joint outer hyper-parameter path.
    fn requires_joint_outer_hyper_path(&self) -> bool {
        true
    }

    /// Hessian workspace built with default options.
    ///
    /// Returns `None` for per-z. Without flex or time-wiggle, it is a row-kernel workspace
    /// over all rows; otherwise it is the dedicated survival workspace.
    fn exact_newton_joint_hessian_workspace(
        &self,
        block_states: &[ParameterBlockState],
        _: &[ParameterBlockSpec],
    ) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
        if self.per_z_logslope_active() {
            return Ok(None);
        }
        if !self.effective_flex_active(block_states)? && !self.flex_timewiggle_active() {
            let kern = SurvivalMarginalSlopeRowKernel::new(self.clone(), block_states.to_vec());
            return Ok(Some(Arc::new(RowKernelHessianWorkspace::new(kern)?)));
        }
        Ok(Some(Arc::new(
            SurvivalMarginalSlopeExactNewtonJointHessianWorkspace::new(
                self.clone(),
                block_states.to_vec(),
                BlockwiseFitOptions::default(),
            )?,
        )))
    }

    /// Same dispatch as `exact_newton_joint_hessian_workspace`, with two differences:
    /// * the row-kernel workspace is restricted to the row set that
    ///   `row_set_from_options` derives from `options`;
    /// * the dedicated workspace receives `options`.
    fn exact_newton_joint_hessian_workspace_with_options(
        &self,
        block_states: &[ParameterBlockState],
        _: &[ParameterBlockSpec],
        options: &BlockwiseFitOptions,
    ) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
        if self.per_z_logslope_active() {
            return Ok(None);
        }
        if !self.effective_flex_active(block_states)? && !self.flex_timewiggle_active() {
            let kern = SurvivalMarginalSlopeRowKernel::new(self.clone(), block_states.to_vec());
            let rows = crate::row_kernel::row_set_from_options(options, self.n);
            return Ok(Some(Arc::new(RowKernelHessianWorkspace::with_rows(
                kern, rows,
            )?)));
        }
        Ok(Some(Arc::new(
            SurvivalMarginalSlopeExactNewtonJointHessianWorkspace::new(
                self.clone(),
                block_states.to_vec(),
                options.clone(),
            )?,
        )))
    }

    /// Inner Hessian-vector products are available except in per-z mode, and only when the
    /// specs match `self.n` rows.
    fn inner_coefficient_hessian_hvp_available(&self, specs: &[ParameterBlockSpec]) -> bool {
        !self.per_z_logslope_active() && parameter_block_specs_match_rows(specs, self.n)
    }

    /// Outer hyper-Hessian HVPs have the same availability rule as the inner ones.
    fn outer_hyper_hessian_hvp_available(&self, specs: &[ParameterBlockSpec]) -> bool {
        !self.per_z_logslope_active() && parameter_block_specs_match_rows(specs, self.n)
    }

    /// First directional derivative of the joint Hessian along `d_beta_flat`.
    ///
    /// `None` in per-z mode. The derivative comes from:
    /// * flex: the time-wiggle-flex or flex-no-wiggle evaluator;
    /// * time-wiggle only: the time-wiggle evaluator;
    /// * otherwise: the row kernel over all rows.
    fn exact_newton_joint_hessian_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        if self.per_z_logslope_active() {
            return Ok(None);
        }
        if self.effective_flex_active(block_states)? {
            let derivative = if self.flex_timewiggle_active() {
                self.exact_newton_joint_hessian_directional_derivative_timewiggle_flex(
                    block_states,
                    d_beta_flat,
                )
            } else {
                self.exact_newton_joint_hessian_directional_derivative_flex_no_wiggle(
                    block_states,
                    d_beta_flat,
                )
            };
            return derivative.map(Some);
        }
        if self.flex_timewiggle_active() {
            return self
                .exact_newton_joint_hessian_directional_derivative_timewiggle(
                    block_states,
                    d_beta_flat,
                )
                .map(Some);
        }
        let kern = SurvivalMarginalSlopeRowKernel::new(self.clone(), block_states.to_vec());
        let sl = d_beta_flat.as_slice().ok_or("non-contiguous d_beta")?;
        crate::row_kernel::row_kernel_directional_derivative(
            &kern,
            &crate::row_kernel::RowSet::All,
            sl,
        )
        .map(Some)
    }

    /// Second directional derivative of the joint Hessian along `d_beta_u_flat` and
    /// `d_beta_v_flat`.
    ///
    /// `None` in per-z mode. The derivative comes from:
    /// * time-wiggle (with or without flex): the time-wiggle evaluator;
    /// * flex only: the flex-no-wiggle evaluator;
    /// * otherwise: the row kernel over all rows.
    fn exact_newton_joint_hessiansecond_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_u_flat: &Array1<f64>,
        d_beta_v_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        if self.per_z_logslope_active() {
            return Ok(None);
        }
        // Probe flex before testing time-wiggle, so that an error from
        // `effective_flex_active` propagates even when the time-wiggle path is taken.
        let flex = self.effective_flex_active(block_states)?;
        if self.flex_timewiggle_active() {
            // NOTE(review): the first-order hook uses a separate `_timewiggle_flex` evaluator
            // when flex is also active. Confirm that this time-wiggle second derivative
            // accounts for an active flex block.
            return self
                .exact_newton_joint_hessiansecond_directional_derivative_timewiggle(
                    block_states,
                    d_beta_u_flat,
                    d_beta_v_flat,
                )
                .map(Some);
        }
        if flex {
            return self
                .exact_newton_joint_hessiansecond_directional_derivative_flex_no_wiggle(
                    block_states,
                    d_beta_u_flat,
                    d_beta_v_flat,
                )
                .map(Some);
        }
        let kern = SurvivalMarginalSlopeRowKernel::new(self.clone(), block_states.to_vec());
        let su = d_beta_u_flat.as_slice().ok_or("non-contiguous d_beta_u")?;
        let sv = d_beta_v_flat.as_slice().ok_or("non-contiguous d_beta_v")?;
        crate::row_kernel::row_kernel_second_directional_derivative(
            &kern,
            &crate::row_kernel::RowSet::All,
            su,
            sv,
        )
        .map(Some)
    }

    /// Directional derivatives of the joint Hessian along every coordinate axis, for the
    /// Jeffreys information term.
    ///
    /// Returns `None` in these cases:
    /// * the default outer Hessian is not trustworthy and the Hessian is not structurally
    ///   coupled;
    /// * per-z mode, where every per-axis derivative is `None`;
    /// * any axis yields `None`.
    ///
    /// Batched all-axes kernels are used for the row-kernel and flex-no-wiggle cases.
    /// Otherwise the axes are evaluated in parallel and reduced in axis order, so the
    /// first error or `None` by axis index determines the result.
    fn joint_jeffreys_information_directional_derivative_all_axes_with_specs(
        &self,
        block_states: &[ParameterBlockState],
        specs: &[ParameterBlockSpec],
    ) -> Result<Option<Vec<Array2<f64>>>, String> {
        if !self.outer_default_trustworthy_for_joint_hessian(specs)
            && !self.joint_hessian_is_structurally_coupled(block_states)?
        {
            return Ok(None);
        }
        // Per-axis derivatives are always `None` here; skip allocating p unit vectors.
        if self.per_z_logslope_active() {
            return Ok(None);
        }
        let flex = self.effective_flex_active(block_states)?;
        let timewiggle = self.flex_timewiggle_active();
        if !flex && !timewiggle {
            let kern = SurvivalMarginalSlopeRowKernel::new(self.clone(), block_states.to_vec());
            let axes = crate::row_kernel::row_kernel_directional_derivative_all_axes(
                &kern,
                &crate::row_kernel::RowSet::All,
            )?;
            return Ok(Some(axes));
        }
        if flex && !timewiggle {
            let axes = self
                .exact_newton_joint_hessian_directional_derivative_flex_no_wiggle_all_axes(
                    block_states,
                )?;
            return Ok(Some(axes));
        }
        let p = specs.iter().map(|spec| spec.design.ncols()).sum::<usize>();
        use rayon::iter::{IntoParallelIterator, ParallelIterator};
        let per_axis: Vec<Result<Option<Array2<f64>>, String>> = (0..p)
            .into_par_iter()
            .map(|axis_idx| {
                let mut axis = Array1::<f64>::zeros(p);
                axis[axis_idx] = 1.0;
                gam_problem::with_nested_parallel(|| {
                    self.exact_newton_joint_hessian_directional_derivative(block_states, &axis)
                })
            })
            .collect();
        // Sequential, index-ordered reduction: the first `Err` or `None` by axis decides.
        per_axis
            .into_iter()
            .collect::<Result<Option<Vec<_>>, String>>()
    }

    /// Second directional derivatives of the joint Hessian, along `d_beta_u_flat` and then
    /// each coordinate axis, for the Jeffreys information term.
    ///
    /// The gating rules match `joint_jeffreys_information_directional_derivative_all_axes_with_specs`.
    /// A batched row-kernel path is used when neither flex nor time-wiggle is active.
    fn joint_jeffreys_information_second_directional_all_axes_with_specs(
        &self,
        block_states: &[ParameterBlockState],
        specs: &[ParameterBlockSpec],
        d_beta_u_flat: &Array1<f64>,
    ) -> Result<Option<Vec<Array2<f64>>>, String> {
        if !self.outer_default_trustworthy_for_joint_hessian(specs)
            && !self.joint_hessian_is_structurally_coupled(block_states)?
        {
            return Ok(None);
        }
        // Per-axis second derivatives are always `None` here; skip allocating p unit vectors.
        if self.per_z_logslope_active() {
            return Ok(None);
        }
        if !self.effective_flex_active(block_states)? && !self.flex_timewiggle_active() {
            let kern = SurvivalMarginalSlopeRowKernel::new(self.clone(), block_states.to_vec());
            let su = d_beta_u_flat.as_slice().ok_or("non-contiguous d_beta_u")?;
            let axes =
                crate::row_kernel::row_kernel_second_directional_derivative_all_axes(
                    &kern,
                    &crate::row_kernel::RowSet::All,
                    su,
                )?;
            return Ok(Some(axes));
        }
        let p = specs.iter().map(|spec| spec.design.ncols()).sum::<usize>();
        use rayon::iter::{IntoParallelIterator, ParallelIterator};
        let per_axis: Vec<Result<Option<Array2<f64>>, String>> = (0..p)
            .into_par_iter()
            .map(|axis_idx| {
                let mut axis = Array1::<f64>::zeros(p);
                axis[axis_idx] = 1.0;
                gam_problem::with_nested_parallel(|| {
                    self.exact_newton_joint_hessiansecond_directional_derivative(
                        block_states,
                        d_beta_u_flat,
                        &axis,
                    )
                })
            })
            .collect();
        // Sequential, index-ordered reduction: the first `Err` or `None` by axis decides.
        per_axis
            .into_iter()
            .collect::<Result<Option<Vec<_>>, String>>()
    }

    /// First-order psi terms.
    ///
    /// The sigma auxiliary coordinate is handled by `sigma_exact_joint_psi_terms`; every
    /// other index goes to `psi_terms`.
    fn exact_newton_joint_psi_terms(
        &self,
        block_states: &[ParameterBlockState],
        specs: &[ParameterBlockSpec],
        derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
        psi_index: usize,
    ) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
        if self.is_sigma_aux_index(derivative_blocks, psi_index) {
            return self.sigma_exact_joint_psi_terms(block_states, specs);
        }
        self.psi_terms(block_states, derivative_blocks, psi_index)
    }

    /// Second-order psi terms.
    ///
    /// The sigma-sigma pair is handled by `sigma_exact_joint_psisecond_order_terms`. Mixed
    /// pairs that involve sigma are not provided (`None`). All other pairs go to
    /// `psi_second_order_terms`.
    fn exact_newton_joint_psisecond_order_terms(
        &self,
        block_states: &[ParameterBlockState],
        _: &[ParameterBlockSpec],
        derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
        psi_i: usize,
        psi_j: usize,
    ) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String> {
        if self.is_sigma_aux_index(derivative_blocks, psi_i)
            || self.is_sigma_aux_index(derivative_blocks, psi_j)
        {
            if psi_i == psi_j {
                return self.sigma_exact_joint_psisecond_order_terms(block_states);
            }
            return Ok(None);
        }
        self.psi_second_order_terms(block_states, derivative_blocks, psi_i, psi_j)
    }

    /// Directional derivative of the psi Hessian along `d_beta_flat`, with the same sigma
    /// dispatch as `exact_newton_joint_psi_terms`.
    fn exact_newton_joint_psihessian_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        _: &[ParameterBlockSpec],
        derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
        psi_index: usize,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        if self.is_sigma_aux_index(derivative_blocks, psi_index) {
            return self
                .sigma_exact_joint_psihessian_directional_derivative(block_states, d_beta_flat);
        }
        self.psi_hessian_directional_derivative(
            block_states,
            derivative_blocks,
            psi_index,
            d_beta_flat,
        )
    }

    /// Psi workspace built with default options; `None` in per-z mode.
    fn exact_newton_joint_psi_workspace(
        &self,
        block_states: &[ParameterBlockState],
        specs: &[ParameterBlockSpec],
        derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
    ) -> Result<Option<Arc<dyn ExactNewtonJointPsiWorkspace>>, String> {
        if self.per_z_logslope_active() {
            return Ok(None);
        }
        Ok(Some(Arc::new(
            crate::marginal_slope_shared::MarginalSlopeExactNewtonPsiWorkspace::new(
                SurvivalMarginalSlopePsiWorkspace::new(
                    self.clone(),
                    block_states.to_vec(),
                    specs.to_vec(),
                    derivative_blocks.to_vec(),
                    BlockwiseFitOptions::default(),
                )?,
            ),
        )))
    }

    /// Psi workspace built with `options`.
    ///
    /// Any auto outer subsample options are applied first. Like the option-less variant and
    /// both Hessian-workspace hooks, this returns `None` in per-z mode.
    fn exact_newton_joint_psi_workspace_with_options(
        &self,
        block_states: &[ParameterBlockState],
        specs: &[ParameterBlockSpec],
        derivative_blocks: &[Vec<crate::custom_family::CustomFamilyBlockPsiDerivative>],
        options: &BlockwiseFitOptions,
    ) -> Result<Option<Arc<dyn ExactNewtonJointPsiWorkspace>>, String> {
        // Per-z log-slope has no workspace support anywhere else in this family.
        if self.per_z_logslope_active() {
            return Ok(None);
        }
        let owned;
        let options: &BlockwiseFitOptions = match self.install_auto_outer_subsample_options(options)
        {
            Some(cloned) => {
                owned = cloned;
                &owned
            }
            None => options,
        };
        Ok(Some(Arc::new(
            crate::marginal_slope_shared::MarginalSlopeExactNewtonPsiWorkspace::new(
                SurvivalMarginalSlopePsiWorkspace::new(
                    self.clone(),
                    block_states.to_vec(),
                    specs.to_vec(),
                    derivative_blocks.to_vec(),
                    options.clone(),
                )?,
            ),
        )))
    }

    /// Linear inequality constraints per block.
    ///
    /// * Block 0 (time) uses `effective_time_linear_constraints`.
    /// * Block 3 uses the score-warp constraints when a score warp is configured.
    /// * The link-deviation block (index 4 with score warp, 3 without) uses the structural
    ///   monotonicity constraints.
    /// * All other blocks are unconstrained.
    ///
    /// # Panics
    /// Panics if `block_spec.name` is empty.
    fn block_linear_constraints(
        &self,
        _: &[ParameterBlockState],
        block_idx: usize,
        block_spec: &ParameterBlockSpec,
    ) -> Result<Option<LinearInequalityConstraints>, String> {
        assert!(!block_spec.name.is_empty());
        if block_idx == 0 {
            return self.effective_time_linear_constraints();
        }
        if block_idx == 3
            && let Some(runtime) = &self.score_warp
        {
            return self.score_warp_linear_constraints(runtime).map(Some);
        }
        let link_block_idx = if self.score_warp.is_some() { 4 } else { 3 };
        if block_idx == link_block_idx
            && let Some(runtime) = &self.link_dev
        {
            return Ok(Some(DeviationRuntime::structural_monotonicity_constraints(
                runtime,
            )));
        }
        Ok(None)
    }

    /// Largest feasible step along `delta`.
    ///
    /// Only the time block (index 0) is limited, via `max_feasible_time_step`; all other
    /// blocks return `None`.
    ///
    /// # Errors
    /// Returns an error instead of panicking when the time block is requested but
    /// `block_states` is empty.
    fn max_feasible_step_size(
        &self,
        block_states: &[ParameterBlockState],
        block_idx: usize,
        delta: &Array1<f64>,
    ) -> Result<Option<f64>, String> {
        if block_idx != 0 {
            return Ok(None);
        }
        let Some(time_state) = block_states.first() else {
            return Err(SurvivalMarginalSlopeError::IncompatibleDimensions {
                reason: "max feasible step requested for the time block but no block states were supplied"
                    .to_string(),
            }
            .into());
        };
        self.max_feasible_time_step(&time_state.beta, delta)
    }

    /// Validate a proposed block update.
    ///
    /// The proposed `beta` is returned unchanged if it is feasible; otherwise an error is
    /// returned. The checks by block are:
    /// * time (0): both the current and the proposed coefficients must pass
    ///   `validate_time_qd1_feasible`.
    /// * score warp (3, when configured): the length must equal the current length and
    ///   `basis_dim * score_dim`, and each per-coordinate slice must pass
    ///   `monotonicity_feasible`.
    /// * link deviation (4 with score warp, else 3): the length must match the current one,
    ///   and both the current and the proposed coefficients must pass
    ///   `monotonicity_feasible`.
    ///
    /// # Panics
    /// Panics if `block_spec.name` is empty.
    fn post_update_block_beta(
        &self,
        block_states: &[ParameterBlockState],
        block_idx: usize,
        block_spec: &ParameterBlockSpec,
        beta: Array1<f64>,
    ) -> Result<Array1<f64>, String> {
        assert!(!block_spec.name.is_empty());
        if block_idx >= block_states.len() {
            return Err(SurvivalMarginalSlopeError::IncompatibleDimensions {
                reason: format!(
                    "post-update block index {} out of range for {} blocks",
                    block_idx,
                    block_states.len()
                ),
            }
            .into());
        }
        if block_idx == 0 {
            let current = &block_states[0].beta;
            self.validate_time_qd1_feasible(current, "current")?;
            self.validate_time_qd1_feasible(&beta, "proposed")?;
            return Ok(beta);
        }
        if block_idx == 3
            && let Some(runtime) = &self.score_warp
        {
            // In range: `block_idx < block_states.len()` was checked above.
            let current = &block_states[block_idx].beta;
            if current.len() != beta.len() {
                return Err(SurvivalMarginalSlopeError::IncompatibleDimensions {
                    reason: format!(
                        "survival score-warp post-update beta length mismatch: current={}, proposed={}",
                        current.len(),
                        beta.len()
                    ),
                }
                .into());
            }
            let expected = runtime.basis_dim() * self.score_dim();
            if beta.len() != expected {
                return Err(SurvivalMarginalSlopeError::IncompatibleDimensions {
                    reason: format!(
                        "survival score-warp post-update beta length mismatch: proposed={}, expected {expected} for K={} and basis dim {}",
                        beta.len(),
                        self.score_dim(),
                        runtime.basis_dim()
                    ),
                }
                .into());
            }
            for coord in 0..self.score_dim() {
                let range = score_warp_component_range(runtime, coord);
                let proposed_local = beta.slice(s![range.clone()]).to_owned();
                runtime
                    .monotonicity_feasible(&proposed_local, &format!("score_warp_dev[z{coord}]"))?;
            }
            return Ok(beta);
        }
        let link_block_idx = if self.score_warp.is_some() { 4 } else { 3 };
        if block_idx == link_block_idx
            && let Some(runtime) = &self.link_dev
        {
            let current = block_states
                .get(link_block_idx)
                .map(|state| &state.beta)
                .ok_or_else(|| "missing survival link-deviation block state".to_string())?;
            if current.len() != beta.len() {
                return Err(SurvivalMarginalSlopeError::IncompatibleDimensions {
                    reason: format!(
                        "survival link-deviation post-update beta length mismatch: current={}, proposed={}",
                        current.len(),
                        beta.len()
                    ),
                }
                .into());
            }
            runtime.monotonicity_feasible(current, "link_dev current")?;
            runtime.monotonicity_feasible(&beta, "link_dev proposed")?;
            return Ok(beta);
        }
        Ok(beta)
    }
}