use super::*;
/// State for the survival marginal-slope family: observation data, entry/exit
/// design matrices and offsets, marginal and log-slope designs, and optional
/// flexible components (score warp, link deviation, influence absorber).
///
/// `#[derive(Clone)]` copies the `Arc` handles rather than their contents, so
/// clones share the event/weight/`z`/offset arrays, the intercept warm-start
/// cache, and the auto-subsample counter and `rho` slot.
#[derive(Clone)]
pub(crate) struct SurvivalMarginalSlopeFamily {
/// Presumably the number of observations (rows) — TODO confirm.
pub(crate) n: usize,
/// Per-row event values (presumably 0/1 indicators of length `n` — confirm).
pub(crate) event: Arc<Array1<f64>>,
/// Per-row weights (presumably of length `n` — confirm).
pub(crate) weights: Arc<Array1<f64>>,
/// Dense matrix `z`; its role is not visible in this section.
pub(crate) z: Arc<Array2<f64>>,
/// Score covariance specification; its use is not visible in this section.
pub(crate) score_covariance: MarginalSlopeCovariance,
/// Optional Gaussian frailty standard deviation (by name; the meaning of
/// `None` is not visible here — presumably "no frailty", confirm).
pub(crate) gaussian_frailty_sd: Option<f64>,
/// Value returned by `time_derivative_lower_bound`, which panics unless it is
/// finite and strictly positive.
pub(crate) derivative_guard: f64,
/// Design evaluated at entry time (inferred from the name — confirm).
pub(crate) design_entry: DesignMatrix,
/// Design evaluated at exit time (inferred from the name — confirm).
pub(crate) design_exit: DesignMatrix,
/// Design for the time derivative at exit (inferred from the name — confirm).
pub(crate) design_derivative_exit: DesignMatrix,
/// Offset paired with `design_entry` (by name).
pub(crate) offset_entry: Arc<Array1<f64>>,
/// Offset paired with `design_exit` (by name).
pub(crate) offset_exit: Arc<Array1<f64>>,
/// Offset paired with `design_derivative_exit` (by name).
pub(crate) derivative_offset_exit: Arc<Array1<f64>>,
/// Design for the marginal block.
pub(crate) marginal_design: DesignMatrix,
/// Design for the log-slope block.
pub(crate) logslope_design: DesignMatrix,
/// Column ranges into the log-slope design (presumably one per surface term
/// — confirm).
pub(crate) logslope_surface_ranges: Vec<std::ops::Range<usize>>,
/// Optional score-warp component; when set, its coefficients are read from
/// block index 3 (see `flex_score_beta`).
pub(crate) score_warp: Option<DeviationRuntime>,
/// Optional link-deviation component; its coefficients are read from block
/// index 4 when `score_warp` is set, else 3 (see `flex_link_beta`).
pub(crate) link_dev: Option<DeviationRuntime>,
/// Optional influence-absorber matrix (Z̃_infl). Row `i` dotted with the
/// influence-absorber coefficients gives that row's index offset
/// (`influence_index_offset`); its column count must equal that coefficient
/// length.
pub(crate) influence_absorber: Option<Array2<f64>>,
/// Optional linear inequality constraints (presumably on the time block —
/// confirm).
pub(crate) time_linear_constraints: Option<LinearInequalityConstraints>,
/// Presumably the knot vector of a time "wiggle" basis — confirm.
pub(crate) time_wiggle_knots: Option<Array1<f64>>,
/// Presumably the degree of the time "wiggle" basis — confirm.
pub(crate) time_wiggle_degree: Option<usize>,
/// Presumably the number of time "wiggle" columns — confirm.
pub(crate) time_wiggle_ncols: usize,
/// Optional per-row cache of entry/exit intercept warm starts, shared between
/// clones.
pub(crate) intercept_warm_starts: Option<Arc<SurvivalInterceptWarmStartCache>>,
/// Shared counter for the automatic subsampling logic (presumably compared
/// against `SURVIVAL_MGS_AUTO_SUBSAMPLE_PHASE1_BUDGET` — confirm).
pub(crate) auto_subsample_phase_counter: Arc<AtomicUsize>,
/// Last `rho` vector recorded by the automatic subsampling logic, behind a
/// shared mutex; its use is not visible in this section.
pub(crate) auto_subsample_last_rho: Arc<Mutex<Option<Array1<f64>>>>,
}
/// Phase-1 budget (12) for the automatic subsampling logic of the survival
/// marginal-slope family.
///
/// NOTE(review): the unit (iterations? evaluations?) and the consumer are not
/// visible in this section — presumably compared with
/// `auto_subsample_phase_counter`; confirm.
pub(crate) const SURVIVAL_MGS_AUTO_SUBSAMPLE_PHASE1_BUDGET: usize = 12;
/// Selects which pair of slot arrays in `SurvivalInterceptWarmStartCache` a
/// lookup or store uses (see `SurvivalInterceptWarmStartCache::slots_for`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum SurvivalInterceptSlotKind {
/// Uses `entry_value` / `entry_tag`.
Entry = 0,
/// Uses `exit_value` / `exit_tag`.
Exit = 1,
}
/// Per-row cache of intercept warm starts built only from atomics (no locks).
///
/// There is one value/tag array pair for entry slots and one for exit slots.
/// Values are stored as `f64` bit patterns. A value is only returned by `load`
/// when its tag slot matches the caller's key; `store` writes the tag as `0`
/// before updating the value. Each array is indexed by row and bounds-checked
/// independently.
pub(crate) struct SurvivalInterceptWarmStartCache {
/// `f64::to_bits` of the cached entry intercept for each row.
pub(crate) entry_value: Vec<std::sync::atomic::AtomicU64>,
/// Key under which the matching `entry_value` was stored (`0` = no valid value).
pub(crate) entry_tag: Vec<std::sync::atomic::AtomicU64>,
/// `f64::to_bits` of the cached exit intercept for each row.
pub(crate) exit_value: Vec<std::sync::atomic::AtomicU64>,
/// Key under which the matching `exit_value` was stored (`0` = no valid value).
pub(crate) exit_tag: Vec<std::sync::atomic::AtomicU64>,
}
impl SurvivalInterceptWarmStartCache {
    /// Returns the `(values, tags)` slot arrays for `kind`.
    #[inline]
    pub(crate) fn slots_for(
        &self,
        kind: SurvivalInterceptSlotKind,
    ) -> (
        &[std::sync::atomic::AtomicU64],
        &[std::sync::atomic::AtomicU64],
    ) {
        match kind {
            SurvivalInterceptSlotKind::Entry => (&self.entry_value, &self.entry_tag),
            SurvivalInterceptSlotKind::Exit => (&self.exit_value, &self.exit_tag),
        }
    }

    /// Reads the cached intercept for `row`/`kind` if it was stored under `beta_tag`.
    ///
    /// Returns `None` when `row` is out of range, when the slot's tag differs
    /// from `beta_tag` (including while a concurrent [`Self::store`] is in
    /// flight), or when the stored value is not finite.
    ///
    /// The tag is read before and after the value, seqlock-style. The acquire
    /// fence between the value read and the second tag read pairs with the
    /// release fence in [`Self::store`]. If the value read observed a write
    /// from a concurrent `store`, the second tag read is guaranteed to see that
    /// store's `0` marker or a later tag, so a value written under a different
    /// tag is rejected. Tags are not monotonic counters, however: an A-B-A
    /// sequence of concurrent stores re-publishing the same tag can still slip
    /// through. Treat the result as a warm-start hint.
    #[inline]
    pub(crate) fn load(
        &self,
        row: usize,
        kind: SurvivalInterceptSlotKind,
        beta_tag: u64,
    ) -> Option<f64> {
        use std::sync::atomic::{fence, Ordering};

        let (values, tags) = self.slots_for(kind);
        let value_slot = values.get(row)?;
        let tag_slot = tags.get(row)?;
        if tag_slot.load(Ordering::Acquire) != beta_tag {
            return None;
        }
        let bits = value_slot.load(Ordering::Relaxed);
        // Without this fence the relaxed value read could be reordered after
        // the re-check below, which would make the re-check meaningless.
        fence(Ordering::Acquire);
        if tag_slot.load(Ordering::Relaxed) != beta_tag {
            return None;
        }
        let value = f64::from_bits(bits);
        value.is_finite().then_some(value)
    }

    /// Stores intercept `a` for `row`/`kind` under `beta_tag`.
    ///
    /// Out-of-range rows are ignored. The tag is first set to `0`, so `0`
    /// must not be used as a real key, because a concurrent `load(…, 0)`
    /// could then accept a half-updated slot. Then the value is written and
    /// the tag is published with release ordering.
    #[inline]
    pub(crate) fn store(&self, row: usize, kind: SurvivalInterceptSlotKind, a: f64, beta_tag: u64) {
        use std::sync::atomic::{fence, Ordering};

        let (values, tags) = self.slots_for(kind);
        let (Some(value_slot), Some(tag_slot)) = (values.get(row), tags.get(row)) else {
            return;
        };
        tag_slot.store(0, Ordering::Relaxed);
        // A release *store* does not keep later writes from moving before it.
        // This fence does, for any reader whose value read observes the write
        // below; it pairs with the acquire fence in `load`.
        fence(Ordering::Release);
        value_slot.store(a.to_bits(), Ordering::Relaxed);
        // Publish: a reader that acquires this tag also sees the value above.
        tag_slot.store(beta_tag, Ordering::Release);
    }
}
/// Builds an empty warm-start cache with `n` entry slots and `n` exit slots.
///
/// Every value slot starts as the bit pattern of NaN and every tag slot
/// starts at `0`, so nothing is cached until a row is stored.
pub(crate) fn new_intercept_warm_start_cache(n: usize) -> Arc<SurvivalInterceptWarmStartCache> {
    use std::sync::atomic::AtomicU64;

    let nan_slots = || -> Vec<AtomicU64> {
        std::iter::repeat_with(|| AtomicU64::new(f64::NAN.to_bits()))
            .take(n)
            .collect()
    };
    let zero_tags = || -> Vec<AtomicU64> {
        std::iter::repeat_with(|| AtomicU64::new(0)).take(n).collect()
    };

    let cache = SurvivalInterceptWarmStartCache {
        entry_value: nan_slots(),
        entry_tag: zero_tags(),
        exit_value: nan_slots(),
        exit_tag: zero_tags(),
    };
    Arc::new(cache)
}
/// Derives the warm-start cache key from the optional `beta_h` / `beta_w`
/// coefficient vectors.
///
/// The vectors are mixed in order, `beta_h` under domain tag `0xa1` and then
/// `beta_w` under `0xa2`. The key comes from `Fnv1a::finish_nonzero`, which
/// presumably never yields `0`, the "write in progress" tag used by the
/// warm-start cache's `store` — confirm in `Fnv1a`.
#[inline]
pub(crate) fn hash_intercept_warm_start_key(
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
) -> u64 {
    let mut hasher = Fnv1a::new();
    for (domain, beta) in [(0xa1, beta_h), (0xa2, beta_w)] {
        hasher.mix_opt_beta(domain, beta);
    }
    hasher.finish_nonzero()
}
/// Optional per-block coefficient vectors; every field is `None` by default
/// (derived `Default`).
///
/// NOTE(review): presumably used to seed or hint the corresponding parameter
/// blocks (time, marginal, log-slope, score-warp, link-deviation,
/// influence-absorber) — confirm against callers.
#[derive(Clone, Default)]
pub(crate) struct ThetaHints {
/// Coefficients for the time block, if provided.
pub(crate) time_beta: Option<Array1<f64>>,
/// Coefficients for the marginal block, if provided.
pub(crate) marginal_beta: Option<Array1<f64>>,
/// Coefficients for the log-slope block, if provided.
pub(crate) logslope_beta: Option<Array1<f64>>,
/// Coefficients for the score-warp block, if provided.
pub(crate) score_warp_beta: Option<Array1<f64>>,
/// Coefficients for the link-deviation block, if provided.
pub(crate) link_dev_beta: Option<Array1<f64>>,
/// Coefficients for the influence-absorber block, if provided.
pub(crate) influence_beta: Option<Array1<f64>>,
}
impl SurvivalMarginalSlopeFamily {
    /// Block index of the first optional flexible block.
    ///
    /// The optional blocks follow in the order score warp, link deviation,
    /// influence absorber, each occupying an index only when configured.
    const FIRST_FLEX_BLOCK: usize = 3;

    /// Returns the time-derivative lower bound, i.e. `derivative_guard`.
    ///
    /// # Panics
    ///
    /// Panics if `derivative_guard` is not finite or not strictly positive.
    pub(crate) fn time_derivative_lower_bound(&self) -> f64 {
        assert!(
            self.derivative_guard.is_finite() && self.derivative_guard > 0.0,
            "survival marginal-slope derivative guard must be finite and positive: derivative_guard={}",
            self.derivative_guard
        );
        self.derivative_guard
    }

    /// Returns `true` when any optional flexible component (score warp, link
    /// deviation, or influence absorber) is configured.
    pub(crate) fn flex_active(&self) -> bool {
        self.score_warp.is_some() || self.link_dev.is_some() || self.influence_absorber.is_some()
    }

    /// Like [`Self::flex_active`], but first checks that every configured
    /// flexible component has its coefficient block in `block_states`.
    ///
    /// # Errors
    ///
    /// Returns a `SurvivalMarginalSlopeError::InvalidInput` (converted to
    /// `String`) naming the first configured component whose block state is
    /// missing.
    pub(crate) fn effective_flex_active(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<bool, String> {
        // Each accessor returns Ok(None) only for a component that is NOT
        // configured, and Err(reason) when a configured component's block is
        // absent. A missing block therefore surfaces as an Err, never as a
        // None, so the typed error is built by mapping that Err.
        let invalid_input = |reason: String| -> String {
            SurvivalMarginalSlopeError::InvalidInput { reason }.into()
        };
        self.flex_score_beta(block_states).map_err(invalid_input)?;
        self.flex_link_beta(block_states).map_err(invalid_input)?;
        self.flex_influence_beta(block_states).map_err(invalid_input)?;
        Ok(self.flex_active())
    }

    /// Returns `Some(&block_states[idx].beta)`, or `Err(missing)` when there is
    /// no block at `idx`.
    fn flex_block_beta<'a>(
        block_states: &'a [ParameterBlockState],
        idx: usize,
        missing: &str,
    ) -> Result<Option<&'a Array1<f64>>, String> {
        block_states
            .get(idx)
            .map(|state| Some(&state.beta))
            .ok_or_else(|| missing.to_string())
    }

    /// Coefficients of the score-warp block, at block index 3.
    ///
    /// Returns `Ok(None)` when no score warp is configured.
    ///
    /// # Errors
    ///
    /// Returns an error when a score warp is configured but `block_states`
    /// has no block at its index.
    pub(crate) fn flex_score_beta<'a>(
        &self,
        block_states: &'a [ParameterBlockState],
    ) -> Result<Option<&'a Array1<f64>>, String> {
        if self.score_warp.is_none() {
            return Ok(None);
        }
        Self::flex_block_beta(
            block_states,
            Self::FIRST_FLEX_BLOCK,
            "missing survival score-warp block state",
        )
    }

    /// Coefficients of the link-deviation block. It sits at block index 4 when
    /// a score warp is configured, otherwise at 3.
    ///
    /// Returns `Ok(None)` when no link deviation is configured.
    ///
    /// # Errors
    ///
    /// Returns an error when a link deviation is configured but
    /// `block_states` has no block at its index.
    pub(crate) fn flex_link_beta<'a>(
        &self,
        block_states: &'a [ParameterBlockState],
    ) -> Result<Option<&'a Array1<f64>>, String> {
        if self.link_dev.is_none() {
            return Ok(None);
        }
        let idx = Self::FIRST_FLEX_BLOCK + usize::from(self.score_warp.is_some());
        Self::flex_block_beta(block_states, idx, "missing survival link-deviation block state")
    }

    /// Coefficients of the influence-absorber block. Its index is 3 plus one
    /// for each of score warp and link deviation that is configured.
    ///
    /// Returns `Ok(None)` when no influence absorber is configured.
    ///
    /// # Errors
    ///
    /// Returns an error when an influence absorber is configured but
    /// `block_states` has no block at its index.
    pub(crate) fn flex_influence_beta<'a>(
        &self,
        block_states: &'a [ParameterBlockState],
    ) -> Result<Option<&'a Array1<f64>>, String> {
        if self.influence_absorber.is_none() {
            return Ok(None);
        }
        let idx = Self::FIRST_FLEX_BLOCK
            + usize::from(self.score_warp.is_some())
            + usize::from(self.link_dev.is_some());
        Self::flex_block_beta(
            block_states,
            idx,
            "missing survival influence-absorber block state",
        )
    }

    /// Influence-absorber contribution for `row`: `Z̃_infl[row, ..] · γ`.
    ///
    /// Returns `Ok(0.0)` when no influence absorber is configured.
    ///
    /// # Errors
    ///
    /// Returns an error when the absorber's block state is missing, when `γ`'s
    /// length differs from the number of Z̃_infl columns, or when `row` is not a
    /// valid Z̃_infl row.
    pub(crate) fn influence_index_offset(
        &self,
        row: usize,
        block_states: &[ParameterBlockState],
    ) -> Result<f64, String> {
        let (Some(z_tilde), Some(gamma)) = (
            self.influence_absorber.as_ref(),
            self.flex_influence_beta(block_states)?,
        ) else {
            return Ok(0.0);
        };
        if gamma.len() != z_tilde.ncols() {
            return Err(format!(
                "survival influence-absorber β length {} != Z̃_infl columns {}",
                gamma.len(),
                z_tilde.ncols()
            ));
        }
        // `row()` panics on an out-of-range index; report it as an error
        // instead, matching the shape check above.
        if row >= z_tilde.nrows() {
            return Err(format!(
                "survival influence-absorber row {} out of range for Z̃_infl with {} rows",
                row,
                z_tilde.nrows()
            ));
        }
        Ok(z_tilde.row(row).dot(gamma))
    }
}