use super::*;
/// Largest exponent handed to `exp` when forming `exp(-eta_log_sigma)`;
/// larger exponents are capped at this value.
const SURVIVAL_EXP_NEG_STABLE_MAX_ARG: f64 = 500.0;

/// Returns `exp(-eta_log_sigma)` with the exponent capped at
/// `SURVIVAL_EXP_NEG_STABLE_MAX_ARG`.
///
/// A NaN input yields NaN.
#[inline]
fn survival_inverse_sigma_from_eta_log_sigma(eta_log_sigma: f64) -> f64 {
    // `f64::min` returns the non-NaN operand, so without this guard a NaN
    // log-sigma would silently become `exp(SURVIVAL_EXP_NEG_STABLE_MAX_ARG)`.
    if eta_log_sigma.is_nan() {
        return f64::NAN;
    }
    (-eta_log_sigma).min(SURVIVAL_EXP_NEG_STABLE_MAX_ARG).exp()
}

/// Returns `(q0, inv_sigma)` where `inv_sigma` is the capped
/// `exp(-eta_log_sigma)` and `q0 = -eta_threshold * inv_sigma`.
///
/// * A NaN `eta_log_sigma` yields `(NaN, NaN)`.
/// * An `eta_threshold` of exactly zero yields `q0 == 0.0`.
/// * When `ln|eta_threshold|` plus the capped exponent exceeds
///   `SURVIVAL_EXP_NEG_STABLE_MAX_ARG`, `q0` saturates to `-f64::MAX` for a
///   positive threshold and `f64::MAX` for a negative one instead of being
///   formed by multiplication.
#[inline]
fn survival_q0_and_inverse_sigma(eta_threshold: f64, eta_log_sigma: f64) -> (f64, f64) {
    // Propagate NaN explicitly: the `min`-based cap below would otherwise mask
    // it and produce a saturated (or huge) q0.
    if eta_log_sigma.is_nan() {
        return (f64::NAN, f64::NAN);
    }
    let inv_sigma = survival_inverse_sigma_from_eta_log_sigma(eta_log_sigma);
    if eta_threshold == 0.0 {
        return (0.0, inv_sigma);
    }
    // Decide on saturation in log space so the product is never formed when
    // its magnitude would be beyond the capped range.
    let log_abs =
        eta_threshold.abs().ln() + (-eta_log_sigma).min(SURVIVAL_EXP_NEG_STABLE_MAX_ARG);
    let q0 = if log_abs > SURVIVAL_EXP_NEG_STABLE_MAX_ARG {
        if eta_threshold > 0.0 {
            -f64::MAX
        } else {
            f64::MAX
        }
    } else {
        -eta_threshold * inv_sigma
    };
    (q0, inv_sigma)
}
/// Evaluates the survival (upper-tail) value `1 - F(eta)` for the given
/// inverse link.
///
/// Probit, logit and complementary log-log links use a direct closed form of
/// the tail; every other link falls back to `1 - failure_jet.mu`, clamped to
/// `[0, 1]`.
#[inline]
fn survival_tail_value_from_failure_jet(
    inverse_link: &InverseLink,
    eta: f64,
    failure_jet: &InverseLinkJet,
) -> f64 {
    match inverse_link {
        InverseLink::Standard(crate::types::StandardLink::Probit) => match eta {
            // NaN and the two infinities are answered directly rather than
            // being passed to `erfc`.
            x if x.is_nan() => f64::NAN,
            x if x == f64::INFINITY => 0.0,
            x if x == f64::NEG_INFINITY => 1.0,
            x => 0.5 * statrs::function::erf::erfc(x / std::f64::consts::SQRT_2),
        },
        InverseLink::Standard(crate::types::StandardLink::Logit) => {
            let odds = eta.exp();
            1.0 / (1.0 + odds)
        }
        InverseLink::Standard(crate::types::StandardLink::CLogLog) => {
            let cumulative = eta.exp();
            (-cumulative).exp()
        }
        _ => {
            let complement = 1.0 - failure_jet.mu;
            complement.clamp(0.0, 1.0)
        }
    }
}
/// Returns `(survival, failure_density)` at `eta` for the given inverse link.
///
/// The survival value comes from `survival_tail_value_from_failure_jet` and is
/// clamped to `[0, 1]`; the density is the `d1` component of the failure jet.
///
/// # Errors
/// Propagates any error from
/// `crate::solver::mixture_link::inverse_link_jet_for_inverse_link`.
#[inline]
fn inverse_link_survival_tail_value_and_failure_density(
    inverse_link: &InverseLink,
    eta: f64,
) -> Result<(f64, f64), EstimationError> {
    let jet = crate::solver::mixture_link::inverse_link_jet_for_inverse_link(inverse_link, eta)?;
    let tail = survival_tail_value_from_failure_jet(inverse_link, eta, &jet);
    let survival = tail.clamp(0.0, 1.0);
    let density = jet.d1;
    Ok((survival, density))
}
/// Predictor that turns a threshold linear predictor and a log-sigma linear
/// predictor into survival probabilities through an inverse link.
pub struct SurvivalPredictor {
/// Coefficients multiplied with the main design matrix to form the threshold
/// linear predictor.
pub beta_threshold: Array1<f64>,
/// Coefficients multiplied with the noise design matrix to form the
/// log-sigma linear predictor.
pub beta_log_sigma: Array1<f64>,
/// Optional dense coefficient covariance used for plug-in standard errors.
/// NOTE(review): presumably ordered threshold block first, then log-sigma
/// block, matching the gradient layout used for the delta method — confirm.
pub covariance: Option<Array2<f64>>,
/// Inverse link used to evaluate the survival tail and failure density.
pub inverse_link: InverseLink,
}
impl SurvivalPredictor {
    /// Builds a predictor from a unified fit result.
    ///
    /// The threshold coefficients are taken from the first block found among
    /// the `Threshold`, `Location` and `Mean` roles (tried in that order); the
    /// log-sigma coefficients come from the `Scale` block. The covariance is a
    /// clone of `unified.covariance_conditional`.
    ///
    /// # Errors
    /// Returns `EstimationError::InvalidInput` when no threshold-like block or
    /// no scale block is present.
    pub(crate) fn from_unified(
        unified: &UnifiedFitResult,
        inverse_link: InverseLink,
    ) -> Result<Self, EstimationError> {
        let threshold_block = unified
            .block_by_role(BlockRole::Threshold)
            .or_else(|| unified.block_by_role(BlockRole::Location))
            .or_else(|| unified.block_by_role(BlockRole::Mean));
        let beta_threshold = match threshold_block {
            Some(block) => block.beta.clone(),
            None => {
                return Err(EstimationError::InvalidInput(
                    "Survival model missing threshold block".to_string(),
                ));
            }
        };

        let beta_log_sigma = match unified.block_by_role(BlockRole::Scale) {
            Some(block) => block.beta.clone(),
            None => {
                return Err(EstimationError::InvalidInput(
                    "Survival model missing scale (log-sigma) block".to_string(),
                ));
            }
        };

        let covariance = unified.covariance_conditional.clone();
        Ok(Self {
            beta_threshold,
            beta_log_sigma,
            covariance,
            inverse_link,
        })
    }

    /// Evaluates the survival probability for every row, in parallel.
    ///
    /// Row `i` uses `eta_threshold[i]` and `eta_log_sigma[i]`; the second
    /// array is indexed with the first array's row count.
    ///
    /// # Errors
    /// Propagates any error raised while evaluating the inverse link.
    fn compute_survival(
        &self,
        eta_threshold: &Array1<f64>,
        eta_log_sigma: &Array1<f64>,
    ) -> Result<Array1<f64>, EstimationError> {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};

        let link = &self.inverse_link;
        let survival_at = |row: usize| -> Result<f64, EstimationError> {
            let (q0, _inv_sigma) =
                survival_q0_and_inverse_sigma(eta_threshold[row], eta_log_sigma[row]);
            inverse_link_survival_tail_value_and_failure_density(link, q0)
                .map(|(survival, _density)| survival)
        };

        let values = (0..eta_threshold.len())
            .into_par_iter()
            .map(survival_at)
            .collect::<Result<Vec<f64>, EstimationError>>()?;
        Ok(Array1::from_vec(values))
    }
}
impl SurvivalPredictor {
fn linear_predictors<'a>(
&self,
input: &'a PredictInput,
) -> Result<(Array1<f64>, Array1<f64>, &'a DesignMatrix), EstimationError> {
let eta_threshold = input.design.dot(&self.beta_threshold) + &input.offset;
let design_noise = input.design_noise.as_ref().ok_or_else(|| {
EstimationError::InvalidInput(
"Survival prediction requires noise (log-sigma) design matrix".to_string(),
)
})?;
let offset_noise = input.offset_noise.as_ref().ok_or_else(|| {
EstimationError::InvalidInput(
"Survival prediction requires noise (log-sigma) offset".to_string(),
)
})?;
let eta_log_sigma = design_noise.dot(&self.beta_log_sigma) + offset_noise;
Ok((eta_threshold, eta_log_sigma, design_noise))
}
fn state_from_backend(
&self,
input: &PredictInput,
backend: &PredictionCovarianceBackend<'_>,
) -> Result<LinearState, EstimationError> {
let (eta_threshold, eta_log_sigma, design_noise) = self.linear_predictors(input)?;
let survival_prob = self.compute_survival(&eta_threshold, &eta_log_sigma)?;
let n = eta_threshold.len();
let p_t = self.beta_threshold.len();
let p_s = self.beta_log_sigma.len();
let eta_se = padded_design_standard_errors_from_backend(
&input.design,
backend,
0,
p_s,
"survival threshold uncertainty",
)?;
let mean_se_vec = linear_predictor_se_from_backend(backend, n, |rows| {
let x_t = design_row_chunk(&input.design, rows.clone())?;
let x_s = design_row_chunk(design_noise, rows.clone())?;
let eta_t_chunk = eta_threshold.slice(ndarray::s![rows.clone()]);
let eta_ls_chunk = eta_log_sigma.slice(ndarray::s![rows.clone()]);
let rows_in_chunk = eta_t_chunk.len();
let mut grad = Array2::<f64>::zeros((rows_in_chunk, p_t + p_s));
for i in 0..rows_in_chunk {
let (q0, inv_sigma) =
survival_q0_and_inverse_sigma(eta_t_chunk[i], eta_ls_chunk[i]);
let (_, failure_density) =
inverse_link_survival_tail_value_and_failure_density(&self.inverse_link, q0)
.map_err(|e| e.to_string())?;
let dsurv_deta_t = failure_density * inv_sigma;
let dsurv_deta_s = failure_density * q0;
for j in 0..p_t {
grad[[i, j]] = dsurv_deta_t * x_t[[i, j]];
}
for j in 0..p_s {
grad[[i, p_t + j]] = dsurv_deta_s * x_s[[i, j]];
}
}
Ok(vec![grad])
})?;
Ok(LinearState {
eta: eta_threshold,
mean: survival_prob,
eta_se: Some(eta_se),
mean_se: Some(mean_se_vec),
covariance_corrected_used: false,
})
}
fn plugin_state_from_covariance(
&self,
input: &PredictInput,
) -> Result<LinearState, EstimationError> {
if let Some(ref cov) = self.covariance {
let backend = PredictionCovarianceBackend::from_dense(cov.view());
self.state_from_backend(input, &backend)
} else {
let (eta_threshold, eta_log_sigma, _) = self.linear_predictors(input)?;
let survival_prob = self.compute_survival(&eta_threshold, &eta_log_sigma)?;
Ok(LinearState {
eta: eta_threshold,
mean: survival_prob,
eta_se: None,
mean_se: None,
covariance_corrected_used: false,
})
}
}
}
impl PredictionTransform for SurvivalPredictor {
    /// Plug-in point prediction; uses the predictor's stored covariance for
    /// standard errors when one is available.
    fn point_state(&self, input: &PredictInput) -> Result<LinearState, EstimationError> {
        self.plugin_state_from_covariance(input)
    }

    /// Builds the linear state for the requested prediction pass.
    ///
    /// * `FullUncertainty`: selects a covariance backend from `fit` according
    ///   to `covariance_mode`, computes the delta-method state from it, and
    ///   records whether the corrected covariance was used.
    /// * `PosteriorMean`: obtains the backend via
    ///   `require_posterior_mean_backend`, projects the coefficient covariance
    ///   onto the two linear predictors, and feeds each row's mean and 2x2
    ///   covariance to `projected_bivariate_posterior_mean_result` together
    ///   with a closure that evaluates the survival probability.
    ///   `covariance_mode` is not consulted in this pass, and `mean_se` is
    ///   left as `None`.
    ///
    /// # Errors
    /// Propagates errors from backend selection, covariance projection, and
    /// inverse-link evaluation.
    fn linear_state(
        &self,
        input: &PredictInput,
        fit: &UnifiedFitResult,
        pass: PredictPass,
        covariance_mode: InferenceCovarianceMode,
    ) -> Result<LinearState, EstimationError> {
        match pass {
            PredictPass::FullUncertainty => {
                let p_total = self.beta_threshold.len() + self.beta_log_sigma.len();
                let (backend, covariance_corrected_used) =
                    fit.select_uncertainty_backend(p_total, covariance_mode, "survival")?;
                let mut state = self.state_from_backend(input, &backend)?;
                state.covariance_corrected_used = covariance_corrected_used;
                Ok(state)
            }
            PredictPass::PosteriorMean => {
                // `covariance_mode` is intentionally unused here. No runtime
                // assertion is needed to mark it as used: it is consumed by
                // the `FullUncertainty` arm above.
                let (eta_threshold, eta_log_sigma, design_noise) =
                    self.linear_predictors(input)?;
                let p_t = self.beta_threshold.len();
                let p_s = self.beta_log_sigma.len();
                let p_total = p_t + p_s;
                let backend = require_posterior_mean_backend(
                    fit,
                    self.covariance.as_ref(),
                    p_total,
                    "survival posterior mean",
                )?;
                let eta_se = padded_design_standard_errors_from_backend(
                    &input.design,
                    &backend,
                    0,
                    p_s,
                    "survival posterior mean",
                )?;
                let (var_t, var_s, cov_ts) = project_two_block_linear_predictor_covariance(
                    &input.design,
                    design_noise,
                    &backend,
                    p_t,
                    p_s,
                    "survival posterior mean",
                )?;
                let quadctx = crate::quadrature::QuadratureContext::new();
                let mean = Array1::from_vec(
                    (0..eta_threshold.len())
                        .map(|i| {
                            projected_bivariate_posterior_mean_result(
                                &quadctx,
                                [eta_threshold[i], eta_log_sigma[i]],
                                // Negative projected variances are floored at
                                // zero before being handed on.
                                [
                                    [var_t[i].max(0.0), cov_ts[i]],
                                    [cov_ts[i], var_s[i].max(0.0)],
                                ],
                                |threshold, log_sigma| {
                                    let (q0, _) =
                                        survival_q0_and_inverse_sigma(threshold, log_sigma);
                                    let (survival, _) =
                                        inverse_link_survival_tail_value_and_failure_density(
                                            &self.inverse_link,
                                            q0,
                                        )?;
                                    Ok(survival)
                                },
                            )
                        })
                        .collect::<Result<Vec<_>, _>>()?,
                );
                Ok(LinearState {
                    eta: eta_threshold,
                    mean,
                    eta_se: Some(eta_se),
                    mean_se: None,
                    covariance_corrected_used: false,
                })
            }
        }
    }

    /// Maps a threshold linear predictor to survival probabilities with the
    /// log-sigma linear predictor fixed at zero for every row.
    fn response(&self, eta: &Array1<f64>) -> Result<Array1<f64>, EstimationError> {
        self.compute_survival(eta, &Array1::zeros(eta.len()))
    }

    /// Interval construction per pass: `SymmetricDelta` for full uncertainty,
    /// `CollapsedDelta` for the posterior mean.
    fn response_jacobian_rows(&self, pass: PredictPass) -> ResponseInterval {
        match pass {
            PredictPass::FullUncertainty => ResponseInterval::SymmetricDelta,
            PredictPass::PosteriorMean => ResponseInterval::CollapsedDelta,
        }
    }

    /// Survival probabilities are reported on the unit interval.
    fn bounds(&self) -> ResponseBounds {
        ResponseBounds::UNIT_PROBABILITY
    }

    /// Response family tag reported for this predictor.
    fn response_family(&self) -> ResponseFamily {
        ResponseFamily::RoystonParmar
    }
}
impl PredictableModel for SurvivalPredictor {
    /// Plug-in response prediction via the shared generic implementation.
    fn predict_plugin_response(
        &self,
        input: &PredictInput,
    ) -> Result<PredictResult, EstimationError> {
        predict_plugin_response_generic(self, input)
    }

    /// Prediction with standard errors via the shared generic implementation.
    fn predict_with_uncertainty(
        &self,
        input: &PredictInput,
    ) -> Result<PredictionWithSE, EstimationError> {
        predict_with_uncertainty_generic(self, input)
    }

    /// This predictor reports no separate noise scale; always returns
    /// `Ok(None)` regardless of the input.
    fn predict_noise_scale(
        &self,
        predict_input: &PredictInput,
    ) -> Result<Option<Array1<f64>>, EstimationError> {
        // The input is deliberately ignored; discard it explicitly instead of
        // asserting on its size at runtime.
        let _ = predict_input;
        Ok(None)
    }

    /// Full-uncertainty prediction via the shared generic implementation.
    fn predict_full_uncertainty(
        &self,
        input: &PredictInput,
        fit_result: &UnifiedFitResult,
        options: &PredictUncertaintyOptions,
    ) -> Result<PredictUncertaintyResult, EstimationError> {
        predict_full_uncertainty_generic(self, input, fit_result, options)
    }

    /// Posterior-mean prediction via the shared generic implementation.
    fn predict_posterior_mean(
        &self,
        input: &PredictInput,
        fit: &UnifiedFitResult,
        options: &PosteriorMeanOptions,
    ) -> Result<PredictPosteriorMeanResult, EstimationError> {
        predict_posterior_mean_generic(self, input, fit, options)
    }

    /// Number of coefficient blocks: threshold and scale.
    fn n_blocks(&self) -> usize {
        2
    }

    /// Roles of the coefficient blocks, in order.
    fn block_roles(&self) -> Vec<BlockRole> {
        vec![BlockRole::Threshold, BlockRole::Scale]
    }
}