use super::*;
/// Reports whether a survival inverse link has free parameters.
///
/// * `Sas` and `BetaLogistic` links always answer `true`.
/// * A `Mixture` link answers `true` only when its `rho` is non-empty.
/// * `LatentCLogLog` and `Standard` links answer `false`.
pub(crate) fn survival_inverse_link_has_free_parameters(link: &InverseLink) -> bool {
    match link {
        InverseLink::Standard(_) | InverseLink::LatentCLogLog(_) => false,
        InverseLink::Mixture(mixture) => !mixture.rho.is_empty(),
        InverseLink::BetaLogistic(_) | InverseLink::Sas(_) => true,
    }
}
/// Turns a finished outer optimization into an inverse link.
///
/// # Arguments
/// * `result` - outcome of the outer rho optimizer.
/// * `context` - label prefixed to every error message.
/// * `recover` - callback that rebuilds the link from the final `rho`; it is
///   only invoked when `result.converged` is set.
///
/// # Errors
/// * an `IntegrationFailed` workflow error (converted to `String`) when the
///   optimizer did not converge;
/// * a plain message when `recover` returns `None` for the final `rho`.
pub(crate) fn recover_converged_survival_inverse_link<R>(
    result: crate::solver::rho_optimizer::OuterResult,
    context: &str,
    recover: R,
) -> Result<InverseLink, String>
where
    R: FnOnce(&Array1<f64>) -> Option<InverseLink>,
{
    if result.converged {
        return match recover(&result.rho) {
            Some(link) => Ok(link),
            None => Err(format!(
                "{context} produced an invalid inverse-link state at rho={:?}",
                result.rho.to_vec()
            )),
        };
    }
    let reason = format!(
        "{context} did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
        result.iterations,
        result.final_value,
        result.final_grad_norm_report(),
    );
    Err(WorkflowError::IntegrationFailed { reason }.into())
}
// Floor applied to each smoothing parameter before taking its logarithm when
// a survival fit result is assembled, so `ln` never sees a non-positive value.
const LOG_LAMBDA_UNDERFLOW_FLOOR: f64 = 1e-300;
// Inner PIRLS settings used for every trial evaluated by the survival
// transformation smoothing-parameter search (iteration cap, convergence
// tolerance, step-halving cap and minimum step size, respectively).
const SURVIVAL_TRANSFORMATION_PIRLS_MAX_ITERATIONS: usize = 400;
const SURVIVAL_TRANSFORMATION_PIRLS_CONVERGENCE_TOL: f64 = 1e-6;
const SURVIVAL_TRANSFORMATION_PIRLS_MAX_STEP_HALVING: usize = 40;
const SURVIVAL_TRANSFORMATION_PIRLS_MIN_STEP_SIZE: f64 = 1e-12;
// Finite cost reported to the outer optimizer for a trial whose inner solve
// or objective evaluation was unusable (paired with an all-zero gradient).
const SURVIVAL_TRANSFORMATION_NONCONVERGED_TRIAL_COST: f64 = 1e12;
/// Bundle of a survival location-scale term fit, the inverse link it was
/// fitted with, and the optional link-wiggle basis description.
struct SurvivalLocationScaleProfile {
    /// The fitted location-scale terms.
    fit: SurvivalLocationScaleTermFitResult,
    /// Inverse link associated with `fit`.
    inverse_link: InverseLink,
    /// Knot vector of the link wiggle, when one is present.
    wiggle_knots: Option<Array1<f64>>,
    /// Spline degree of the link wiggle, when one is present.
    wiggle_degree: Option<usize>,
}

impl SurvivalLocationScaleProfile {
    /// Consumes the profile and repackages its fields, unchanged, as a
    /// `SurvivalLocationScaleFitResult`.
    fn into_result(self) -> SurvivalLocationScaleFitResult {
        let Self {
            fit,
            inverse_link,
            wiggle_knots,
            wiggle_degree,
        } = self;
        SurvivalLocationScaleFitResult {
            fit,
            inverse_link,
            wiggle_knots,
            wiggle_degree,
        }
    }
}
/// Resolves the inverse link to use for a link-wiggle refit.
///
/// The fitted link state reported by `fit` for `spec` is translated into the
/// matching `InverseLink` variant; a `Standard(None)` state falls back to a
/// clone of `fallback`.
///
/// # Errors
/// Returns the stringified error from `fitted_link_state`, or the error from
/// `require_inverse_link_supports_joint_wiggle` when the resolved link is
/// rejected for joint wiggle fitting.
fn resolved_wiggle_inverse_link(
    spec: &LikelihoodSpec,
    fit: &UnifiedFitResult,
    fallback: &InverseLink,
) -> Result<InverseLink, String> {
    let fitted_state = fit.fitted_link_state(spec).map_err(|e| e.to_string())?;
    let resolved = match fitted_state {
        FittedLinkState::Standard(None) => fallback.clone(),
        FittedLinkState::Standard(Some(link)) => InverseLink::Standard(link),
        FittedLinkState::LatentCLogLog { state } => InverseLink::LatentCLogLog(state),
        FittedLinkState::Sas { state, .. } => InverseLink::Sas(state),
        FittedLinkState::BetaLogistic { state, .. } => InverseLink::BetaLogistic(state),
        FittedLinkState::Mixture { state, .. } => InverseLink::Mixture(state),
    };
    require_inverse_link_supports_joint_wiggle(&resolved, "standard link wiggle")?;
    Ok(resolved)
}
/// Builds a `DeviationBlockConfig` from a formula-level link-wiggle spec.
///
/// `penalty_order` is the largest entry of `wiggle.penalty_orders` (2 when
/// the list is empty); `monotonicity_eps` comes from
/// `WigglePenaltyConfig::cubic_triple_operator_default()`. All other fields
/// are copied from `wiggle`.
fn deviation_block_config_from_formula_linkwiggle(
    wiggle: &LinkWiggleFormulaSpec,
) -> DeviationBlockConfig {
    let monotonicity_eps = WigglePenaltyConfig::cubic_triple_operator_default().monotonicity_eps;
    let highest_penalty_order = wiggle.penalty_orders.iter().copied().max().unwrap_or(2);
    DeviationBlockConfig {
        degree: wiggle.degree,
        num_internal_knots: wiggle.num_internal_knots,
        penalty_order: highest_penalty_order,
        penalty_orders: wiggle.penalty_orders.clone(),
        double_penalty: wiggle.double_penalty,
        monotonicity_eps,
    }
}
/// Deviation-block configurations selected for a marginal-slope model.
/// Each field is `None` when no corresponding configuration was requested.
pub(crate) struct MarginalSlopeDeviationRouting {
/// Configuration for the score-warp deviation block.
pub(crate) score_warp: Option<DeviationBlockConfig>,
/// Configuration for the link deviation block.
pub(crate) link_dev: Option<DeviationBlockConfig>,
}
/// Routes the two optional link-wiggle formula specs to deviation blocks.
///
/// The main-formula wiggle becomes `link_dev` and the log-slope wiggle
/// becomes `score_warp`; an absent spec yields `None` for its block.
/// The visible implementation never returns `Err`.
pub(crate) fn route_marginal_slope_deviation_blocks(
    main_linkwiggle: Option<&LinkWiggleFormulaSpec>,
    logslope_linkwiggle: Option<&LinkWiggleFormulaSpec>,
) -> Result<MarginalSlopeDeviationRouting, String> {
    let link_dev = main_linkwiggle.map(deviation_block_config_from_formula_linkwiggle);
    let score_warp = logslope_linkwiggle.map(deviation_block_config_from_formula_linkwiggle);
    Ok(MarginalSlopeDeviationRouting {
        score_warp,
        link_dev,
    })
}
/// Accepts only frailty specs that are absent or a Gaussian shift with a
/// fixed sigma, returning an equivalent owned spec.
///
/// # Errors
/// Returns a `MissingDependency` workflow error (converted to `String`,
/// prefixed with `context`) for a Gaussian shift without a fixed sigma and
/// for any hazard-multiplier frailty.
pub(crate) fn fixed_gaussian_shift_frailty_from_spec(
    frailty: &FrailtySpec,
    context: &str,
) -> Result<FrailtySpec, String> {
    // Accepted specs return immediately; rejected ones only pick the message.
    let reason = match frailty {
        FrailtySpec::None => return Ok(FrailtySpec::None),
        FrailtySpec::GaussianShift {
            sigma_fixed: Some(sigma),
        } => {
            return Ok(FrailtySpec::GaussianShift {
                sigma_fixed: Some(*sigma),
            });
        }
        FrailtySpec::GaussianShift { sigma_fixed: None } => {
            format!("{context} currently requires a fixed GaussianShift sigma")
        }
        FrailtySpec::HazardMultiplier { .. } => {
            format!("{context} requires FrailtySpec::GaussianShift or no frailty")
        }
    };
    Err(WorkflowError::MissingDependency { reason }.into())
}
/// Fits a standard model described by `request`, optionally followed by a
/// binomial mean link-wiggle refit.
///
/// The baseline fit is dispatched on the request contents:
/// 1. `latent_coord` present: latent-coordinate optimization (rejected when
///    coefficient groups or penalty-block gamma priors are also supplied);
/// 2. coefficient groups or penalty-block gamma priors present: the grouped
///    fitter, after which the term collection is frozen from the design;
/// 3. otherwise: the spatial length-scale optimizing fitter.
///
/// When `request.wiggle` is `None` the baseline fit is returned as is. When
/// the wiggle refit fails, a warning is logged and the baseline fit is
/// returned instead of an error.
///
/// # Errors
/// Returns the stringified error of whichever baseline fitter ran, of link
/// resolution, or of wiggle-basis selection.
pub(crate) fn fit_standard_model(
request: StandardFitRequest<'_>,
) -> Result<StandardFitResult, String> {
let fitted = if let Some(latent_coord) = request.latent_coord.as_ref() {
// Latent-coordinate fits cannot be combined with grouped coefficients
// or gamma priors on penalty blocks.
if !request.coefficient_groups.is_empty() || !request.penalty_block_gamma_priors.is_empty()
{
return Err("latent-coordinate standard fits do not support coefficient_groups or penalty_block_gamma_priors in the same request".to_string());
}
fit_term_collectionwith_latent_coord_optimization(
request.data.view(),
request.y.clone(),
request.weights.clone(),
request.offset.clone(),
&request.spec,
latent_coord,
request.family.clone(),
&request.options,
)
.map_err(|e| e.to_string())?
} else if !request.coefficient_groups.is_empty()
|| !request.penalty_block_gamma_priors.is_empty()
{
let fitted = fit_term_collection_with_coefficient_groups_and_penalty_block_gamma_priors(
request.data.view(),
request.y.view(),
request.weights.view(),
request.offset.view(),
&request.spec,
&request.coefficient_groups,
&request.penalty_block_gamma_priors,
request.family.clone(),
&request.options,
)
.map_err(|e| e.to_string())?;
// The grouped fitter's result is repackaged with a spec frozen from
// the fitted design; this path sets no kappa timing.
let resolvedspec =
crate::smooth::freeze_term_collection_from_design(&request.spec, &fitted.design)
.map_err(|e| e.to_string())?;
crate::terms::smooth::FittedTermCollectionWithSpec {
fit: fitted.fit,
design: fitted.design,
resolvedspec,
adaptive_diagnostics: fitted.adaptive_diagnostics,
kappa_timing: None,
}
} else {
fit_term_collectionwith_spatial_length_scale_optimization(
request.data.view(),
request.y.clone(),
request.weights.clone(),
request.offset.clone(),
&request.spec,
request.family.clone(),
&request.options,
&request.kappa_options,
)
.map_err(|e| e.to_string())?
};
// Baseline (no-wiggle) result; the fitted link is captured before `fit` is moved.
let result = StandardFitResult {
saved_link_state: fitted.fit.fitted_link.clone(),
fit: fitted.fit,
design: fitted.design,
resolvedspec: fitted.resolvedspec,
adaptive_diagnostics: fitted.adaptive_diagnostics,
kappa_timing: fitted.kappa_timing,
wiggle_knots: None,
wiggle_degree: None,
};
let Some(wiggle) = request.wiggle else {
return Ok(result);
};
let wiggle_options = wiggle.refit_options.clone();
// Use the link the baseline fit actually ended with, falling back to the
// link named in the wiggle request when the fit reports none.
let wiggle_link_kind =
resolved_wiggle_inverse_link(&request.family, &result.fit, &wiggle.link_kind)?;
let selected_wiggle_basis = select_binomial_mean_link_wiggle_basis_from_pilot(
&result.design,
&result.fit,
&WiggleBlockConfig {
degree: wiggle.wiggle.degree,
num_internal_knots: wiggle.wiggle.num_internal_knots,
penalty_order: 2,
double_penalty: wiggle.wiggle.double_penalty,
},
&wiggle.wiggle.penalty_orders,
)?;
let solved = match fit_binomial_mean_wiggle_terms_with_selected_basis(
request.data.view(),
&result.resolvedspec,
&result.design,
&result.fit,
&request.y,
&request.weights,
wiggle_link_kind,
selected_wiggle_basis,
&wiggle_options,
&request.kappa_options,
) {
Ok(solved) => solved,
Err(e) => {
// Deliberate best-effort: a failed wiggle refit degrades to the
// baseline fit rather than failing the whole request.
log::warn!(
"[linkwiggle] binomial mean link-wiggle joint solve did not converge ({e}); \
falling back to the no-wiggle baseline fit (the large-smoothing limit of the \
penalized wiggle model, which contains it as a limiting case)"
);
return Ok(result);
}
};
// Fit, design and resolved spec come from the wiggle refit; link state,
// adaptive diagnostics and kappa timing are kept from the baseline fit.
Ok(StandardFitResult {
saved_link_state: result.saved_link_state,
fit: solved.fit,
design: solved.design,
resolvedspec: solved.resolvedspec,
adaptive_diagnostics: result.adaptive_diagnostics,
kappa_timing: result.kappa_timing,
wiggle_knots: Some(solved.wiggle_knots),
wiggle_degree: Some(solved.wiggle_degree),
})
}
/// Family-independent pieces of a location-scale fit request, as produced by
/// `LocationScaleWorkflowAdapter::into_parts`.
struct LocationScaleWorkflowParts<'a, S> {
/// Borrowed view of the data matrix.
data: ArrayView2<'a, f64>,
/// Family-specific term specification.
spec: S,
/// Link-wiggle configuration; `None` selects the plain fit.
wiggle: Option<LinkWiggleConfig>,
/// Options forwarded to the blockwise fitter.
options: BlockwiseFitOptions,
/// Options forwarded for spatial length-scale optimization.
kappa_options: SpatialLengthScaleOptimizationOptions,
}
/// Family-specific hooks consumed by `fit_location_scale_with_optional_wiggle`.
trait LocationScaleWorkflowAdapter {
/// Family-specific term specification.
type Spec;
/// Family-specific request type accepted by the workflow.
type Request<'a>;
/// Family-specific result type produced by the workflow.
type Result;
/// Splits a request into the pieces the shared driver needs.
fn into_parts<'a>(request: Self::Request<'a>) -> LocationScaleWorkflowParts<'a, Self::Spec>;
/// Fits the pilot model from a borrowed spec; the driver calls this before
/// the wiggle refit.
fn fit_pilot(
data: ArrayView2<'_, f64>,
spec: &Self::Spec,
options: &BlockwiseFitOptions,
kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<BlockwiseTermFitResult, String>;
/// Refits the model with a link wiggle, given the pilot fit and the
/// requested wiggle configuration. Consumes the spec.
fn refit_with_selected_wiggle(
data: ArrayView2<'_, f64>,
spec: Self::Spec,
pilot: &BlockwiseTermFitResult,
wiggle_cfg: &LinkWiggleConfig,
options: &BlockwiseFitOptions,
kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<BlockwiseTermWiggleFitResult, String>;
/// Fits the model without a link wiggle; used when no wiggle is requested.
fn fit_plain(
data: ArrayView2<'_, f64>,
spec: Self::Spec,
options: &BlockwiseFitOptions,
kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<BlockwiseTermFitResult, String>;
/// Wraps a plain (no-wiggle) fit in the family's result type.
fn assemble_plain(fit: BlockwiseTermFitResult) -> Self::Result;
/// Wraps a wiggle fit, its knots and degree, and the optional wiggle
/// coefficients in the family's result type.
fn assemble_with_wiggle(
fit: BlockwiseTermFitResult,
wiggle_knots: Array1<f64>,
wiggle_degree: usize,
beta_link_wiggle: Option<Vec<f64>>,
) -> Self::Result;
}
/// Shared driver for location-scale fits with an optional link wiggle.
///
/// Without a wiggle configuration the adapter's plain fit is assembled and
/// returned. Otherwise a pilot fit is computed, the adapter refits with the
/// wiggle, and the refit is reassembled into a `BlockwiseTermFitResult`.
/// The wiggle coefficients are taken from the block state at index 2 when
/// such a state exists.
///
/// # Errors
/// Propagates errors from the adapter's fitters and from
/// `BlockwiseTermFitResult::try_from_parts`.
fn fit_location_scale_with_optional_wiggle<A: LocationScaleWorkflowAdapter>(
    request: A::Request<'_>,
) -> Result<A::Result, String> {
    let LocationScaleWorkflowParts {
        data,
        spec,
        wiggle,
        options,
        kappa_options,
    } = A::into_parts(request);

    let wiggle_cfg = match wiggle {
        Some(cfg) => cfg,
        None => {
            return A::fit_plain(data, spec, &options, &kappa_options).map(A::assemble_plain);
        }
    };

    let pilot_fit = A::fit_pilot(data, &spec, &options, &kappa_options)?;
    let refit = A::refit_with_selected_wiggle(
        data,
        spec,
        &pilot_fit,
        &wiggle_cfg,
        &options,
        &kappa_options,
    )?;

    // Read the wiggle coefficients before the unified fit is moved below.
    let beta_link_wiggle = refit
        .fit
        .fit
        .block_states
        .get(2)
        .map(|wiggle_state| wiggle_state.beta.to_vec());
    let parts = BlockwiseTermFitResultParts {
        fit: refit.fit.fit,
        meanspec_resolved: refit.fit.meanspec_resolved,
        noisespec_resolved: refit.fit.noisespec_resolved,
        mean_design: refit.fit.mean_design,
        noise_design: refit.fit.noise_design,
    };
    let assembled_fit = BlockwiseTermFitResult::try_from_parts(parts)?;

    Ok(A::assemble_with_wiggle(
        assembled_fit,
        refit.wiggle_knots,
        refit.wiggle_degree,
        beta_link_wiggle,
    ))
}
/// Adapter that plugs the Gaussian location-scale fitters into the shared
/// optional-wiggle workflow.
struct GaussianLocationScaleWorkflow;

impl LocationScaleWorkflowAdapter for GaussianLocationScaleWorkflow {
    type Spec = GaussianLocationScaleTermSpec;
    type Request<'a> = GaussianLocationScaleFitRequest<'a>;
    type Result = GaussianLocationScaleFitResult;

    /// Moves the request's fields into the family-independent parts struct.
    fn into_parts<'a>(request: Self::Request<'a>) -> LocationScaleWorkflowParts<'a, Self::Spec> {
        let data = request.data;
        let spec = request.spec;
        let wiggle = request.wiggle;
        let options = request.options;
        let kappa_options = request.kappa_options;
        LocationScaleWorkflowParts {
            data,
            spec,
            wiggle,
            options,
            kappa_options,
        }
    }

    /// Fits the pilot model on a field-by-field copy of `spec`, leaving the
    /// caller's spec available for the wiggle refit.
    fn fit_pilot(
        data: ArrayView2<'_, f64>,
        spec: &Self::Spec,
        options: &BlockwiseFitOptions,
        kappa_options: &SpatialLengthScaleOptimizationOptions,
    ) -> Result<BlockwiseTermFitResult, String> {
        let pilot_spec = GaussianLocationScaleTermSpec {
            y: spec.y.clone(),
            weights: spec.weights.clone(),
            meanspec: spec.meanspec.clone(),
            log_sigmaspec: spec.log_sigmaspec.clone(),
            mean_offset: spec.mean_offset.clone(),
            log_sigma_offset: spec.log_sigma_offset.clone(),
        };
        fit_gaussian_location_scale_terms(data, pilot_spec, options, kappa_options)
    }

    /// Selects the wiggle basis from the pilot fit (penalty order fixed at 2
    /// in the block config) and refits with that basis.
    fn refit_with_selected_wiggle(
        data: ArrayView2<'_, f64>,
        spec: Self::Spec,
        pilot: &BlockwiseTermFitResult,
        wiggle_cfg: &LinkWiggleConfig,
        options: &BlockwiseFitOptions,
        kappa_options: &SpatialLengthScaleOptimizationOptions,
    ) -> Result<BlockwiseTermWiggleFitResult, String> {
        let block_config = WiggleBlockConfig {
            degree: wiggle_cfg.degree,
            num_internal_knots: wiggle_cfg.num_internal_knots,
            penalty_order: 2,
            double_penalty: wiggle_cfg.double_penalty,
        };
        let chosen_basis = select_gaussian_location_scale_link_wiggle_basis_from_pilot(
            pilot,
            &block_config,
            &wiggle_cfg.penalty_orders,
        )?;
        fit_gaussian_location_scale_terms_with_selected_wiggle(
            data,
            spec,
            chosen_basis,
            options,
            kappa_options,
        )
    }

    /// Fits the model without a link wiggle.
    fn fit_plain(
        data: ArrayView2<'_, f64>,
        spec: Self::Spec,
        options: &BlockwiseFitOptions,
        kappa_options: &SpatialLengthScaleOptimizationOptions,
    ) -> Result<BlockwiseTermFitResult, String> {
        fit_gaussian_location_scale_terms(data, spec, options, kappa_options)
    }

    /// Wraps a no-wiggle fit; wiggle fields are `None` and the response
    /// scale starts at `1.0`.
    fn assemble_plain(fit: BlockwiseTermFitResult) -> Self::Result {
        GaussianLocationScaleFitResult {
            fit,
            wiggle_knots: None,
            wiggle_degree: None,
            beta_link_wiggle: None,
            response_scale: 1.0,
        }
    }

    /// Wraps a wiggle fit together with its knots, degree and coefficients;
    /// the response scale starts at `1.0`.
    fn assemble_with_wiggle(
        fit: BlockwiseTermFitResult,
        wiggle_knots: Array1<f64>,
        wiggle_degree: usize,
        beta_link_wiggle: Option<Vec<f64>>,
    ) -> Self::Result {
        let wiggle_knots = Some(wiggle_knots);
        let wiggle_degree = Some(wiggle_degree);
        GaussianLocationScaleFitResult {
            fit,
            wiggle_knots,
            wiggle_degree,
            beta_link_wiggle,
            response_scale: 1.0,
        }
    }
}
/// Adapter that plugs the binomial location-scale fitters into the shared
/// optional-wiggle workflow.
struct BinomialLocationScaleWorkflow;

impl LocationScaleWorkflowAdapter for BinomialLocationScaleWorkflow {
    type Spec = BinomialLocationScaleTermSpec;
    type Request<'a> = BinomialLocationScaleFitRequest<'a>;
    type Result = BinomialLocationScaleFitResult;

    /// Moves the request's fields into the family-independent parts struct.
    fn into_parts<'a>(request: Self::Request<'a>) -> LocationScaleWorkflowParts<'a, Self::Spec> {
        let data = request.data;
        let spec = request.spec;
        let wiggle = request.wiggle;
        let options = request.options;
        let kappa_options = request.kappa_options;
        LocationScaleWorkflowParts {
            data,
            spec,
            wiggle,
            options,
            kappa_options,
        }
    }

    /// Checks that the spec's link supports a joint wiggle, then fits the
    /// pilot model on a field-by-field copy of `spec`.
    fn fit_pilot(
        data: ArrayView2<'_, f64>,
        spec: &Self::Spec,
        options: &BlockwiseFitOptions,
        kappa_options: &SpatialLengthScaleOptimizationOptions,
    ) -> Result<BlockwiseTermFitResult, String> {
        require_inverse_link_supports_joint_wiggle(
            &spec.link_kind,
            "binomial location-scale link wiggle",
        )?;
        let pilot_spec = BinomialLocationScaleTermSpec {
            y: spec.y.clone(),
            weights: spec.weights.clone(),
            link_kind: spec.link_kind.clone(),
            thresholdspec: spec.thresholdspec.clone(),
            log_sigmaspec: spec.log_sigmaspec.clone(),
            threshold_offset: spec.threshold_offset.clone(),
            log_sigma_offset: spec.log_sigma_offset.clone(),
        };
        fit_binomial_location_scale_terms(data, pilot_spec, options, kappa_options)
    }

    /// Selects the wiggle basis from the pilot fit (penalty order fixed at 2
    /// in the block config) and refits with that basis.
    fn refit_with_selected_wiggle(
        data: ArrayView2<'_, f64>,
        spec: Self::Spec,
        pilot: &BlockwiseTermFitResult,
        wiggle_cfg: &LinkWiggleConfig,
        options: &BlockwiseFitOptions,
        kappa_options: &SpatialLengthScaleOptimizationOptions,
    ) -> Result<BlockwiseTermWiggleFitResult, String> {
        let block_config = WiggleBlockConfig {
            degree: wiggle_cfg.degree,
            num_internal_knots: wiggle_cfg.num_internal_knots,
            penalty_order: 2,
            double_penalty: wiggle_cfg.double_penalty,
        };
        let chosen_basis = select_binomial_location_scale_link_wiggle_basis_from_pilot(
            pilot,
            &block_config,
            &wiggle_cfg.penalty_orders,
        )?;
        fit_binomial_location_scale_terms_with_selected_wiggle(
            data,
            spec,
            chosen_basis,
            options,
            kappa_options,
        )
    }

    /// Fits the model without a link wiggle.
    fn fit_plain(
        data: ArrayView2<'_, f64>,
        spec: Self::Spec,
        options: &BlockwiseFitOptions,
        kappa_options: &SpatialLengthScaleOptimizationOptions,
    ) -> Result<BlockwiseTermFitResult, String> {
        fit_binomial_location_scale_terms(data, spec, options, kappa_options)
    }

    /// Wraps a no-wiggle fit; all wiggle fields are `None`.
    fn assemble_plain(fit: BlockwiseTermFitResult) -> Self::Result {
        BinomialLocationScaleFitResult {
            fit,
            wiggle_knots: None,
            wiggle_degree: None,
            beta_link_wiggle: None,
        }
    }

    /// Wraps a wiggle fit together with its knots, degree and coefficients.
    fn assemble_with_wiggle(
        fit: BlockwiseTermFitResult,
        wiggle_knots: Array1<f64>,
        wiggle_degree: usize,
        beta_link_wiggle: Option<Vec<f64>>,
    ) -> Self::Result {
        let wiggle_knots = Some(wiggle_knots);
        let wiggle_degree = Some(wiggle_degree);
        BinomialLocationScaleFitResult {
            fit,
            wiggle_knots,
            wiggle_degree,
            beta_link_wiggle,
        }
    }
}
/// Population standard deviation (divisor `n`, not `n - 1`) of `v`.
///
/// Returns `0.0` for an empty view. The variance is clamped at zero before
/// the square root is taken.
fn gaussian_response_sample_std(v: ArrayView1<'_, f64>) -> f64 {
    let count = v.len();
    if count == 0 {
        return 0.0;
    }
    let n = count as f64;
    let mean = v.iter().copied().sum::<f64>() / n;
    let sum_sq_dev = v
        .iter()
        .map(|&x| {
            let d = x - mean;
            d * d
        })
        .sum::<f64>();
    // `n >= 1` here, so dividing by `n` equals dividing by `n.max(1.0)`.
    let variance = sum_sq_dev / n;
    f64::max(variance, 0.0).sqrt()
}
/// Rewrites a Gaussian location-scale fit in place by the factor
/// `response_scale` (`s` below).
///
/// * Mean / location / link-wiggle blocks: coefficients and linear predictors
///   are multiplied by `s`.
/// * Scale blocks: `ln(s)` is added to the coefficients at the noise design's
///   intercept columns and to the block's linear predictor.
/// * Wiggle knots and wiggle coefficients are multiplied by `s`.
/// * Both covariance matrices are rescaled entry-wise by the product of the
///   per-coefficient factors (`s` for mean-type blocks, `1` otherwise).
/// * `standard_deviation` and `max_abs_eta` are multiplied by `s`.
/// * Likelihood-type scalars are shifted by multiples of `n_obs * ln(s)`.
/// * `result.response_scale` is set to `s`.
fn rescale_gaussian_location_scale_to_raw(
result: &mut GaussianLocationScaleFitResult,
response_scale: f64,
) {
use crate::estimate::BlockRole;
let s = response_scale;
let ln_s = s.ln();
let scale_intercept_range = result.fit.noise_design.intercept_range.clone();
// Offset of the current block inside the joint coefficient vector.
let mut joint_offset = 0usize;
for (block_idx, block) in result.fit.fit.blocks.iter_mut().enumerate() {
let block_len = block.beta.len();
match block.role {
BlockRole::Mean | BlockRole::Location | BlockRole::LinkWiggle => {
block.beta.mapv_inplace(|v| v * s);
// Mirror the scaling in the joint vector only when it is long
// enough to contain this whole block.
if result.fit.fit.beta.len() >= joint_offset + block_len {
for i in 0..block_len {
result.fit.fit.beta[joint_offset + i] *= s;
}
}
if let Some(state) = result.fit.fit.block_states.get_mut(block_idx) {
state.beta.mapv_inplace(|v| v * s);
state.eta.mapv_inplace(|v| v * s);
}
}
BlockRole::Scale => {
// Only intercept columns are shifted; every write is bounds-guarded.
for col in scale_intercept_range.clone() {
if col < block.beta.len() {
block.beta[col] += ln_s;
}
let joint_col = joint_offset + col;
if joint_col < result.fit.fit.beta.len() {
result.fit.fit.beta[joint_col] += ln_s;
}
if let Some(state) = result.fit.fit.block_states.get_mut(block_idx)
&& col < state.beta.len()
{
state.beta[col] += ln_s;
}
}
// The scale block's linear predictor is shifted regardless of
// how many intercept columns were found.
if let Some(state) = result.fit.fit.block_states.get_mut(block_idx) {
state.eta.mapv_inplace(|v| v + ln_s);
}
}
BlockRole::Time | BlockRole::Threshold => {
}
}
joint_offset += block_len;
}
if let Some(knots) = result.wiggle_knots.as_mut() {
knots.mapv_inplace(|v| v * s);
}
if let Some(beta_w) = result.beta_link_wiggle.as_mut() {
for coef in beta_w.iter_mut() {
*coef *= s;
}
}
// One multiplicative factor per coefficient, in block order.
let mut row_factors: Vec<f64> = Vec::new();
for block in &result.fit.fit.blocks {
let f = match block.role {
BlockRole::Mean | BlockRole::Location | BlockRole::LinkWiggle => s,
BlockRole::Scale | BlockRole::Time | BlockRole::Threshold => 1.0,
};
row_factors.extend(std::iter::repeat_n(f, block.beta.len()));
}
// Scales the leading m-by-m part of a covariance, where m is limited by
// the matrix dimensions and by the number of known factors.
let rescale_cov = |cov: &mut Array2<f64>| {
let m = cov.nrows().min(cov.ncols()).min(row_factors.len());
for i in 0..m {
for j in 0..m {
cov[[i, j]] *= row_factors[i] * row_factors[j];
}
}
};
if let Some(cov) = result.fit.fit.covariance_conditional.as_mut() {
rescale_cov(cov);
}
if let Some(cov) = result.fit.fit.covariance_corrected.as_mut() {
rescale_cov(cov);
}
result.fit.fit.standard_deviation *= s;
result.fit.fit.max_abs_eta *= s;
// The observation count is the linear-predictor length of the first block
// state; the adjustments are skipped when no state exists or it is empty.
if let Some(n_obs) = result
.fit
.fit
.block_states
.first()
.map(|state| state.eta.len() as f64)
.filter(|&n| n > 0.0)
{
let ln_s = s.ln();
result.fit.fit.log_likelihood -= n_obs * ln_s;
result.fit.fit.deviance += 2.0 * n_obs * ln_s;
result.fit.fit.reml_score += n_obs * ln_s;
result.fit.fit.penalized_objective += n_obs * ln_s;
}
result.response_scale = s;
}
/// Fits a Gaussian location-scale model on a standardized response.
///
/// The response and the mean offset are divided by the response's standard
/// deviation (floored at `1e-6`) before fitting. The fitted result is then
/// passed through `rescale_gaussian_location_scale_to_raw` with that same
/// scale.
///
/// # Errors
/// Propagates any error from the shared location-scale workflow.
pub(crate) fn fit_gaussian_location_scale_model(
    mut request: GaussianLocationScaleFitRequest<'_>,
) -> Result<GaussianLocationScaleFitResult, String> {
    let sample_std = gaussian_response_sample_std(request.spec.y.view());
    let response_scale = sample_std.max(1e-6);
    let needs_standardizing = response_scale != 1.0;
    if needs_standardizing {
        let spec = &mut request.spec;
        spec.y.mapv_inplace(|v| v / response_scale);
        spec.mean_offset.mapv_inplace(|v| v / response_scale);
    }
    let mut fitted =
        fit_location_scale_with_optional_wiggle::<GaussianLocationScaleWorkflow>(request)?;
    rescale_gaussian_location_scale_to_raw(&mut fitted, response_scale);
    Ok(fitted)
}
/// Fits a dispersion-GLM location-scale model and tags the result with the
/// `kind` read from the request's spec.
///
/// # Errors
/// Propagates any error from `fit_dispersion_glm_location_scale_terms`.
pub(crate) fn fit_dispersion_location_scale_model(
    request: DispersionLocationScaleFitRequest<'_>,
) -> Result<DispersionLocationScaleFitResult, String> {
    // Read the kind first: the spec is moved into the fitter below.
    let family_kind = request.spec.kind;
    let fitted_terms = fit_dispersion_glm_location_scale_terms(
        request.data,
        request.spec,
        &request.options,
        &request.kappa_options,
    )?;
    Ok(DispersionLocationScaleFitResult {
        fit: fitted_terms,
        kind: family_kind,
    })
}
/// Fits a binomial location-scale model by running the shared
/// optional-wiggle workflow with the binomial adapter.
///
/// # Errors
/// Propagates any error from the shared workflow.
pub(crate) fn fit_binomial_location_scale_model(
request: BinomialLocationScaleFitRequest<'_>,
) -> Result<BinomialLocationScaleFitResult, String> {
fit_location_scale_with_optional_wiggle::<BinomialLocationScaleWorkflow>(request)
}
/// Half the sum of the working state's deviance and penalty term.
fn survival_working_reml_score(state: &crate::pirls::WorkingState) -> f64 {
    let penalized_deviance = state.deviance + state.penalty_term;
    0.5 * penalized_deviance
}
/// Builds a Weibull baseline configuration from linear-time coefficients.
///
/// The shape is read from `beta[1]` and the scale is `anchor`. Returns `None`
/// when `beta` has fewer than two entries, or when either the shape or the
/// anchor is not a finite, strictly positive number.
fn fitted_weibull_baseline_from_linear_time_beta(
    beta: &Array1<f64>,
    anchor: f64,
) -> Option<crate::families::survival::construction::SurvivalBaselineConfig> {
    let is_positive_finite = |value: f64| value.is_finite() && value > 0.0;
    let shape = if beta.len() >= 2 {
        beta[1]
    } else {
        return None;
    };
    if !(is_positive_finite(shape) && is_positive_finite(anchor)) {
        return None;
    }
    let baseline = crate::families::survival::construction::SurvivalBaselineConfig {
        target: SurvivalBaselineTarget::Weibull,
        scale: Some(anchor),
        shape: Some(shape),
        rate: None,
        makeham: None,
    };
    Some(baseline)
}
/// Computes effective degrees of freedom for a penalized survival fit.
///
/// For every penalty block with a positive `lambda` and at least one column,
/// the quantity `lambda * trace(H^{-1} S)` is evaluated, where `H` is the
/// dense Hessian stored in `state` and `S` is the block's penalty matrix
/// embedded at the block's coefficient range. The trace is clamped to
/// `[0, block width]`; the block EDF is the width minus that value. Blocks
/// with a non-positive `lambda` or zero width report their full width as EDF
/// and a zero trace.
///
/// # Returns
/// `(edf_total, edf_by_block, penalty_block_trace, dense_hessian)`, where
/// `edf_total` is the coefficient count minus the summed traces, clamped to
/// `[0, p]`.
///
/// # Errors
/// * the Hessian cannot be factorized, even after up to seven ridge
///   escalations starting at `1e-10 * max|diag(H)|`;
/// * a penalty block has an inverted range, a range reaching past the
///   Hessian dimension, or a penalty matrix whose shape differs from the
///   range width (reported as errors rather than index panics);
/// * the trace solve fails or any resulting quantity is non-finite.
fn survival_transformation_edf(
    state: &crate::pirls::WorkingState,
    penalty_blocks: &[PenaltyBlock],
) -> Result<(f64, Vec<f64>, Vec<f64>, Array2<f64>), String> {
    let h_dense = state.hessian.to_dense();
    let p = h_dense.nrows();
    let h_sym = crate::linalg::matrix::SymmetricMatrix::Dense(h_dense.clone());
    let factor = {
        let scale = h_sym.max_abs_diag();
        let min_step = scale * 1e-10;
        let mut ridge = 0.0_f64;
        let mut attempts = 0_usize;
        loop {
            // First attempt is un-ridged; later attempts grow the ridge tenfold.
            let candidate = if ridge > 0.0 {
                h_sym.addridge(ridge).unwrap_or_else(|_| h_sym.clone())
            } else {
                h_sym.clone()
            };
            if let Ok(f) = candidate.factorize() {
                break f;
            }
            attempts += 1;
            if attempts >= 8 {
                return Err("survival edf: penalized Hessian could not be factorized".to_string());
            }
            ridge = if ridge <= 0.0 { min_step } else { ridge * 10.0 };
        }
    };
    let mut edf_by_block = vec![0.0_f64; penalty_blocks.len()];
    let mut penalty_block_trace = vec![0.0_f64; penalty_blocks.len()];
    let mut total_trace = 0.0_f64;
    for (kk, block) in penalty_blocks.iter().enumerate() {
        // An inverted range would underflow the width computation below.
        if block.range.start > block.range.end {
            return Err(format!(
                "survival edf: penalty block {kk} has an inverted coefficient range {}..{}",
                block.range.start, block.range.end
            ));
        }
        let block_cols = block.range.end - block.range.start;
        if block.lambda <= 0.0 || block_cols == 0 {
            edf_by_block[kk] = block_cols as f64;
            penalty_block_trace[kk] = 0.0;
            continue;
        }
        // Geometry is validated only on the path that indexes with it, so
        // blocks skipped above keep their previous behavior.
        if block.range.end > p {
            return Err(format!(
                "survival edf: penalty block {kk} range {}..{} exceeds the Hessian dimension {p}",
                block.range.start, block.range.end
            ));
        }
        if block.matrix.nrows() != block_cols || block.matrix.ncols() != block_cols {
            return Err(format!(
                "survival edf: penalty block {kk} has shape {}x{} but its range has width {block_cols}",
                block.matrix.nrows(),
                block.matrix.ncols()
            ));
        }
        // Embed the block penalty into a p-by-width right-hand side.
        let mut rhs = Array2::<f64>::zeros((p, block_cols));
        for c in 0..block_cols {
            for r in 0..block_cols {
                rhs[[block.range.start + r, c]] = block.matrix[[r, c]];
            }
        }
        let sol = factor
            .solvemulti(&rhs)
            .map_err(|e| format!("survival edf trace solve failed: {e}"))?;
        // Only the diagonal of the block rows of H^{-1} S enters the trace.
        let trace: f64 = (0..block_cols)
            .map(|j| sol[[block.range.start + j, j]])
            .sum();
        let lam_trace = (block.lambda * trace).clamp(0.0, block_cols as f64);
        total_trace += lam_trace;
        penalty_block_trace[kk] = lam_trace;
        edf_by_block[kk] = (block_cols as f64 - lam_trace).clamp(0.0, block_cols as f64);
    }
    let edf_total = (p as f64 - total_trace).clamp(0.0, p as f64);
    // `clamp` passes NaN through, so finiteness is checked explicitly.
    if !edf_total.is_finite()
        || edf_by_block.iter().any(|v| !v.is_finite())
        || penalty_block_trace.iter().any(|v| !v.is_finite())
    {
        return Err("survival edf: non-finite effective degrees of freedom".to_string());
    }
    Ok((edf_total, edf_by_block, penalty_block_trace, h_dense))
}
/// Selects smoothing parameters for a survival transformation model by
/// running the outer rho optimizer around an inner PIRLS solve.
///
/// # Arguments
/// * `model` - working model; cloned for every trial, never mutated here.
/// * `penalty_blocks` - supply the seed `lambda` values.
/// * `num_smoothing` - number of leading lambdas that are optimized; the
///   remaining lambdas stay at their seed values.
/// * `beta0` - starting coefficients for every inner PIRLS solve.
/// * `structural_lower_bounds` - optional coefficient lower bounds passed to
///   PIRLS.
///
/// # Returns
/// * `Ok(None)` when `num_smoothing == 0`.
/// * `Ok(Some(lambdas))` otherwise. These are the seed lambdas when the outer
///   run returns an error; otherwise each of the leading entries is replaced
///   by `exp(rho)` whenever that value is finite and positive.
///
/// # Errors
/// None are returned directly by the visible code once `num_smoothing > 0`:
/// failures inside a trial surface through the outer run, whose error is
/// converted into the seed-lambda fallback.
///
/// NOTE(review): the convergence flag of the outer result is not inspected
/// here; confirm that accepting a non-converged `rho` is intended.
fn optimize_survival_transformation_smoothing(
model: &crate::families::survival::WorkingModelSurvival,
penalty_blocks: &[PenaltyBlock],
num_smoothing: usize,
beta0: &Array1<f64>,
structural_lower_bounds: Option<&Array1<f64>>,
) -> Result<Option<Vec<f64>>, String> {
use crate::solver::rho_optimizer::{Derivative, HessianResult, OuterEval, OuterProblem};
if num_smoothing == 0 {
return Ok(None);
}
let seed_lambdas: Vec<f64> = penalty_blocks.iter().map(|b| b.lambda).collect();
// Seed rho is the log of the leading seed lambdas, floored at 1e-12.
let seed_rho = Array1::from_iter(
seed_lambdas
.iter()
.take(num_smoothing)
.map(|&l| l.max(1e-12).ln()),
);
// Single-entry memo of the most recent (rho, cost, gradient) evaluation;
// the cost and gradient callbacks below both go through `eval_at`.
let eval_cache: std::cell::RefCell<Option<(Array1<f64>, f64, Array1<f64>)>> =
std::cell::RefCell::new(None);
let eval_at = |rho_smooth: &Array1<f64>| -> Result<(f64, Array1<f64>), String> {
if let Some((cached_rho, cached_cost, cached_grad)) = eval_cache.borrow().as_ref()
&& cached_rho == rho_smooth
{
return Ok((*cached_cost, cached_grad.clone()));
}
let mut candidate = model.clone();
let mut lambdas = seed_lambdas.clone();
for k in 0..num_smoothing {
lambdas[k] = rho_smooth[k].exp();
}
candidate
.set_penalty_lambdas(&lambdas)
.map_err(|e| e.to_string())?;
let opts = crate::pirls::WorkingModelPirlsOptions {
max_iterations: SURVIVAL_TRANSFORMATION_PIRLS_MAX_ITERATIONS,
convergence_tolerance: SURVIVAL_TRANSFORMATION_PIRLS_CONVERGENCE_TOL,
adaptive_kkt_tolerance: None,
max_step_halving: SURVIVAL_TRANSFORMATION_PIRLS_MAX_STEP_HALVING,
min_step_size: SURVIVAL_TRANSFORMATION_PIRLS_MIN_STEP_SIZE,
firth_bias_reduction: false,
coefficient_lower_bounds: structural_lower_bounds.cloned(),
linear_constraints: None,
initial_lm_lambda: None,
geodesic_acceleration: false,
arrow_schur: None,
};
// A hard PIRLS error propagates as `Err`; it is not turned into a
// high-cost trial the way the soft failures below are.
let summary = crate::pirls::runworking_model_pirls(
&mut candidate,
crate::types::Coefficients::new(beta0.clone()),
&opts,
|_| {},
)
.map_err(|err| format!("survival smoothing PIRLS failed: {err}"))?;
// Soft failure: log, cache, and report a large finite cost with a zero
// gradient instead of an error.
let bad_trial = |reason: &str| -> Result<(f64, Array1<f64>), String> {
log::info!(
"[OUTER #1123] survival transformation smoothing candidate ρ rejected ({reason}): \
inner PIRLS status={:?} grad_norm={:.3e} iters={} — returning high finite cost so \
BFGS steps away from the un-fittable region toward the converged seed",
summary.status,
summary.lastgradient_norm,
summary.iterations,
);
let cost = SURVIVAL_TRANSFORMATION_NONCONVERGED_TRIAL_COST;
let grad = Array1::zeros(num_smoothing);
*eval_cache.borrow_mut() = Some((rho_smooth.to_owned(), cost, grad.clone()));
Ok((cost, grad))
};
let inner_converged = matches!(
summary.status,
crate::pirls::PirlsStatus::Converged | crate::pirls::PirlsStatus::StalledAtValidMinimum
);
if !inner_converged {
return bad_trial("inner PIRLS did not converge");
}
let beta = summary.beta.as_ref().to_owned();
let state = match candidate.update_state(&beta) {
Ok(state) => state,
Err(_) => return bad_trial("inner state evaluation failed"),
};
// The objective receives log-lambdas of every strictly positive lambda,
// not only the optimized leading ones.
let full_rho = Array1::from_iter(lambdas.iter().filter(|&&l| l > 0.0).map(|&l| l.ln()));
let (cost, grad_full) =
match candidate.unified_lamlobjective_and_rhogradient(&beta, &state, &full_rho) {
Ok(pair) => pair,
Err(_) => return bad_trial("LAML evaluation failed"),
};
if grad_full.len() < num_smoothing || !cost.is_finite() {
return bad_trial("LAML cost non-finite or gradient too short");
}
// Only the leading `num_smoothing` gradient entries are reported.
let grad = grad_full.slice(s![..num_smoothing]).to_owned();
if grad.iter().any(|g| !g.is_finite()) {
return bad_trial("LAML gradient non-finite");
}
*eval_cache.borrow_mut() = Some((rho_smooth.to_owned(), cost, grad.clone()));
Ok((cost, grad))
};
// Search box: 12 log-units either side of the seed.
let lower = seed_rho.mapv(|v| v - 12.0);
let upper = seed_rho.mapv(|v| v + 12.0);
let problem = OuterProblem::new(num_smoothing)
.with_gradient(Derivative::Analytic)
.with_hessian(crate::solver::rho_optimizer::DeclaredHessianForm::Unavailable)
.with_tolerance(1e-4)
.with_max_iter(120)
.with_bounds(lower, upper)
.with_initial_rho(seed_rho.clone())
.with_seed_config(crate::seeding::SeedConfig {
max_seeds: 1,
seed_budget: 1,
..Default::default()
});
let context =
format!("survival transformation smoothing-parameter selection (dim={num_smoothing})");
let mut obj = problem.build_objective(
(),
// Cost-only callback.
|_: &mut (), rho: &Array1<f64>| {
eval_at(rho)
.map(|(c, _)| c)
.map_err(crate::estimate::EstimationError::InvalidInput)
},
// Cost-and-gradient callback; no Hessian or coefficient hint is supplied.
|_: &mut (), rho: &Array1<f64>| {
let (cost, gradient) =
eval_at(rho).map_err(crate::estimate::EstimationError::InvalidInput)?;
Ok(OuterEval {
cost,
gradient,
hessian: HessianResult::Unavailable,
inner_beta_hint: None,
})
},
None::<fn(&mut ())>,
None::<
fn(
&mut (),
&Array1<f64>,
)
-> Result<crate::solver::rho_optimizer::EfsEval, crate::estimate::EstimationError>,
>,
);
let result = match problem.run(&mut obj, &context) {
Ok(result) => result,
Err(err) => {
// Deliberate best-effort: an unusable outer run degrades to the seed.
log::warn!(
"[#1123] survival transformation smoothing selector did not produce a usable ρ \
({err}); falling back to the seed λ (the CLI fits this same model at the seed and \
recovers the truth)"
);
return Ok(Some(seed_lambdas));
}
};
let selected_rho = result.rho;
let mut lambdas = seed_lambdas;
// Overwrite a seed lambda only with a finite, strictly positive value.
for k in 0..num_smoothing.min(selected_rho.len()) {
let lam = selected_rho[k].exp();
if lam.is_finite() && lam > 0.0 {
lambdas[k] = lam;
}
}
Ok(Some(lambdas))
}
/// Packages a converged survival PIRLS solve as a `UnifiedFitResult`.
///
/// # Arguments
/// * `beta` - fitted coefficients, stored as a single `Mean` block.
/// * `lambdas` - smoothing parameters, one per entry of `penalty_blocks`.
/// * `summary` - PIRLS summary supplying status, iteration count, gradient
///   norm and `max_abs_eta`.
/// * `state` - working state supplying log-likelihood, deviance, penalty
///   term and the Hessian used for the EDF computation.
/// * `penalty_blocks` - penalty blocks used for the EDF computation.
///
/// The REML score, which is also stored as the penalized objective, is
/// `0.5 * (deviance + penalty)`. `outer_converged` is set when the PIRLS
/// status is `Converged` or `StalledAtValidMinimum`.
///
/// # Errors
/// * any input scalar or vector is non-finite;
/// * `lambdas` and `penalty_blocks` differ in length (reported as an error
///   instead of a panic);
/// * the EDF computation fails;
/// * `UnifiedFitResult::try_from_parts` rejects the assembled parts.
fn survival_unified_fit_result(
    beta: Array1<f64>,
    lambdas: Array1<f64>,
    summary: &crate::pirls::WorkingModelPirlsResult,
    state: &crate::pirls::WorkingState,
    penalty_blocks: &[PenaltyBlock],
) -> Result<UnifiedFitResult, String> {
    // Floor before `ln` so a zero lambda does not yield negative infinity.
    let log_lambdas = lambdas.mapv(|v| v.max(LOG_LAMBDA_UNDERFLOW_FLOOR).ln());
    let reml_score = survival_working_reml_score(state);
    let outer_converged = matches!(
        summary.status,
        crate::pirls::PirlsStatus::Converged | crate::pirls::PirlsStatus::StalledAtValidMinimum
    );
    crate::estimate::validate_all_finite("survival fit beta", beta.iter().copied())?;
    crate::estimate::validate_all_finite("survival fit lambdas", lambdas.iter().copied())?;
    crate::estimate::ensure_finite_scalar("survival fit log_likelihood", state.log_likelihood)?;
    crate::estimate::ensure_finite_scalar("survival fit deviance", state.deviance)?;
    crate::estimate::ensure_finite_scalar("survival fit penalty", state.penalty_term)?;
    crate::estimate::ensure_finite_scalar("survival fit reml_score", reml_score)?;
    crate::estimate::ensure_finite_scalar("survival fit gradient_norm", summary.lastgradient_norm)?;
    crate::estimate::ensure_finite_scalar("survival fit max_abs_eta", summary.max_abs_eta)?;
    // The per-block EDF vectors have one entry per penalty block, so they
    // can only line up with `lambdas` when the two lengths agree. Check this
    // before the (more expensive) EDF solve.
    if lambdas.len() != penalty_blocks.len() {
        return Err(format!(
            "survival fit: got {} smoothing parameters for {} penalty blocks",
            lambdas.len(),
            penalty_blocks.len()
        ));
    }
    let (edf_total, edf_by_block, penalty_block_trace, penalized_hessian) =
        survival_transformation_edf(state, penalty_blocks)?;
    let inference = crate::estimate::FitInference {
        edf_by_block,
        penalty_block_trace,
        edf_total,
        smoothing_correction: None,
        penalized_hessian: penalized_hessian.into(),
        working_weights: Array1::zeros(0),
        working_response: Array1::zeros(0),
        reparam_qs: None,
        dispersion: crate::estimate::Dispersion::Known(1.0),
        beta_covariance: None,
        beta_standard_errors: None,
        beta_covariance_corrected: None,
        beta_standard_errors_corrected: None,
        beta_covariance_frequentist: None,
        coefficient_influence: None,
        weighted_gram: None,
        bias_correction_beta: None,
    };
    UnifiedFitResult::try_from_parts(crate::estimate::UnifiedFitResultParts {
        blocks: vec![crate::estimate::FittedBlock {
            beta,
            role: crate::estimate::BlockRole::Mean,
            edf: edf_total,
            // Cloned because `lambdas` itself is moved into the parts below.
            lambdas: lambdas.clone(),
        }],
        log_lambdas,
        lambdas,
        likelihood_family: Some(LikelihoodSpec::royston_parmar()),
        likelihood_scale: crate::types::LikelihoodScaleMetadata::Unspecified,
        log_likelihood_normalization: crate::types::LogLikelihoodNormalization::UserProvided,
        log_likelihood: state.log_likelihood,
        deviance: state.deviance,
        reml_score,
        stable_penalty_term: state.penalty_term,
        penalized_objective: reml_score,
        used_device: false,
        outer_iterations: summary.iterations,
        outer_converged,
        outer_gradient_norm: Some(summary.lastgradient_norm),
        standard_deviation: 1.0,
        covariance_conditional: None,
        covariance_corrected: None,
        inference: Some(inference),
        fitted_link: FittedLinkState::Standard(None),
        geometry: None,
        block_states: Vec::new(),
        pirls_status: summary.status,
        max_abs_eta: summary.max_abs_eta,
        constraint_kkt: None,
        artifacts: crate::estimate::FitArtifacts {
            pirls: None,
            ..Default::default()
        },
        inner_cycles: 0,
    })
    .map_err(|err| err.to_string())
}
/// Stacks `cause_count` copies of `pooled_seed` end to end.
///
/// The result has length `pooled_seed.len() * cause_count`; copy `c`
/// occupies positions `c * p .. (c + 1) * p`, where `p` is the seed length.
pub(crate) fn replicate_pooled_baseline_seed_per_cause(
    pooled_seed: ArrayView1<'_, f64>,
    cause_count: usize,
) -> Array1<f64> {
    let seed_len = pooled_seed.len();
    let mut stacked: Vec<f64> = Vec::with_capacity(seed_len * cause_count);
    for _ in 0..cause_count {
        stacked.extend(pooled_seed.iter().copied());
    }
    Array1::from_vec(stacked)
}
/// Fits the cause-specific survival transformation through the custom-family solver.
///
/// One `CauseSpecificRoystonParmarBlock` and one `ParameterBlockSpec` are built per cause.
/// Every cause shares the same stacked `[time | covariate]` design matrices and the same
/// list of `penalty_blocks`; only the event indicator, the starting coefficients, the
/// penalty labels and the gauge priority differ between causes.
///
/// `beta0_flat` must hold `p * cause_count` starting coefficients, laid out cause by cause,
/// where `p` is the number of time columns plus the number of covariate columns.
///
/// # Errors
/// Returns an error when no cause is present, when `beta0_flat` has the wrong length, when a
/// penalty block's range or matrix shape is inconsistent with `p`, or when family
/// construction, the rho prior, the fit itself, or the Weibull baseline recovery fails.
fn fit_cause_specific_survival_transformation_custom(
spec: &SurvivalTransformationTermSpec,
resolvedspec: TermCollectionSpec,
baseline_cfg: crate::families::survival::construction::SurvivalBaselineConfig,
prepared: PreparedSurvivalTimeStack,
dense_cov_design: &Array2<f64>,
penalty_blocks: Vec<PenaltyBlock>,
beta0_flat: Array1<f64>,
derivative_floor: f64,
penalty_block_gamma_priors: &[(String, f64, f64)],
) -> Result<SurvivalTransformationFitResult, String> {
let cause_count =
crate::families::survival::cause_count_from_event_codes(spec.event_target.view())
.into_workflow_result()?;
if cause_count == 0 {
return Err(WorkflowError::MissingDependency {
reason: "cause-specific custom survival fit requires at least one cause".to_string(),
}
.into());
}
// Dimensions: n observations, p = time columns followed by covariate columns.
let n = spec.event_target.len();
let p_time_total = prepared.time_design_exit.ncols();
let p_cov = dense_cov_design.ncols();
let p = p_time_total + p_cov;
if beta0_flat.len() != p * cause_count {
return Err(WorkflowError::SchemaMismatch {
reason: format!(
"cause-specific survival initial beta length mismatch: got {}, expected {}",
beta0_flat.len(),
p * cause_count
),
}
.into());
}
let dense_time_entry = prepared.time_design_entry.to_dense();
let dense_time_exit = prepared.time_design_exit.to_dense();
let dense_time_derivative = prepared.time_design_derivative_exit.to_dense();
// Stack the time designs into the leading columns of the entry/exit/derivative designs.
let mut x_entry = Array2::<f64>::zeros((n, p));
let mut x_exit = Array2::<f64>::zeros((n, p));
let mut x_derivative = Array2::<f64>::zeros((n, p));
if p_time_total > 0 {
x_entry
.slice_mut(s![.., ..p_time_total])
.assign(&dense_time_entry);
x_exit
.slice_mut(s![.., ..p_time_total])
.assign(&dense_time_exit);
x_derivative
.slice_mut(s![.., ..p_time_total])
.assign(&dense_time_derivative);
}
// Covariate columns are copied into the entry and exit designs only; the derivative
// design keeps zeros in those columns.
if p_cov > 0 {
x_entry
.slice_mut(s![.., p_time_total..])
.assign(dense_cov_design);
x_exit
.slice_mut(s![.., p_time_total..])
.assign(dense_cov_design);
}
let mut family_blocks = Vec::with_capacity(cause_count);
let mut block_specs = Vec::with_capacity(cause_count);
for cause in 0..cause_count {
// Cause codes are 1-based; the per-cause indicator is 1 where the observed code matches.
let cause_code = (cause + 1) as u8;
let event_target = spec
.event_target
.mapv(|observed| u8::from(observed == cause_code));
family_blocks.push(crate::families::survival::CauseSpecificRoystonParmarBlock {
age_entry: spec.age_entry.clone(),
age_exit: spec.age_exit.clone(),
event_target,
sampleweight: spec.weights.clone(),
x_entry: x_entry.clone(),
x_exit: x_exit.clone(),
x_derivative: x_derivative.clone(),
offset_eta_entry: prepared.eta_offset_entry.clone() + &spec.covariate_offset,
offset_eta_exit: prepared.eta_offset_exit.clone() + &spec.covariate_offset,
offset_derivative_exit: prepared.derivative_offset_exit.clone(),
derivative_floor,
});
// Re-express every shared penalty block as a blockwise penalty on this cause's
// p-dimensional coefficient vector, validating range and shape along the way.
let mut penalties = Vec::with_capacity(penalty_blocks.len());
let mut nullspace_dims = Vec::with_capacity(penalty_blocks.len());
let mut initial_log_lambdas = Array1::<f64>::zeros(penalty_blocks.len());
for (penalty_idx, block) in penalty_blocks.iter().enumerate() {
if block.range.end > p || block.range.start > block.range.end {
return Err(WorkflowError::SchemaMismatch {
reason: "cause-specific survival penalty range is out of bounds".to_string(),
}
.into());
}
let block_dim = block.range.end - block.range.start;
if block.matrix.nrows() != block_dim || block.matrix.ncols() != block_dim {
return Err(WorkflowError::SchemaMismatch {
reason: format!(
"cause-specific survival penalty {penalty_idx} has shape {}x{} but range has width {block_dim}",
block.matrix.nrows(),
block.matrix.ncols()
),
}
.into());
}
// The precision label uses the same "cause_{c}_penalty_{i}" scheme that
// `cause_specific_survival_rho_prior` matches hyperprior labels against.
penalties.push(
PenaltyMatrix::Blockwise {
local: block.matrix.clone(),
col_range: block.range.clone(),
total_dim: p,
}
.with_precision_label(format!(
"cause_specific_survival_cause_{}_penalty_{penalty_idx}",
cause + 1
)),
);
nullspace_dims.push(block.nullspace_dim);
// Floor lambda before the log so a zero lambda yields a finite starting value.
initial_log_lambdas[penalty_idx] = block.lambda.max(LOG_LAMBDA_UNDERFLOW_FLOOR).ln();
}
// This cause's slice of the flat starting vector.
let beta_start = beta0_flat.slice(s![cause * p..(cause + 1) * p]).to_owned();
// Priority is 100 + (cause_count - cause), saturating at u8::MAX, so earlier causes
// receive the larger value.
let cause_priority =
100u8.saturating_add(u8::try_from(cause_count - cause).unwrap_or(u8::MAX));
let cause_jacobian = std::sync::Arc::new(AdditiveBlockJacobian {
design: x_exit.clone(),
own_output: cause,
n_family_outputs: cause_count,
});
block_specs.push(ParameterBlockSpec {
name: format!("time_cause_{}", cause + 1),
design: crate::matrix::DesignMatrix::from(x_exit.clone()),
offset: prepared.eta_offset_exit.clone() + &spec.covariate_offset,
penalties,
nullspace_dims,
initial_log_lambdas,
initial_beta: Some(beta_start),
gauge_priority: cause_priority,
jacobian_callback: Some(cause_jacobian),
stacked_design: None,
stacked_offset: None,
});
}
let family = crate::families::survival::CauseSpecificRoystonParmarFamily::new(family_blocks)?;
// Covariance computation is switched off for this fit.
let fit_options = BlockwiseFitOptions {
compute_covariance: false,
..Default::default()
};
let rho_prior = cause_specific_survival_rho_prior(
cause_count,
penalty_blocks.len(),
penalty_block_gamma_priors,
)?;
let mut fit = fit_custom_family_with_rho_prior(&family, &block_specs, &fit_options, rho_prior)
.map_err(|err| format!("cause-specific survival custom-family fit failed: {err}"))?;
fit.likelihood_family = Some(LikelihoodSpec::royston_parmar());
let time_basis = crate::families::survival::construction::SavedSurvivalTimeBasis::from_build(
&spec.time_build,
spec.time_anchor,
);
// For a Weibull likelihood without a time wiggle, the baseline is recovered from the
// leading time coefficients of the FIRST coefficient block only; otherwise the incoming
// baseline configuration is passed through unchanged.
let fitted_baseline_cfg = if spec.likelihood_mode == SurvivalLikelihoodMode::Weibull
&& spec.timewiggle.is_none()
{
let first_block = fit.blocks.first().ok_or_else(|| {
"cause-specific survival fit produced no coefficient blocks".to_string()
})?;
let time_beta = first_block
.beta
.slice(s![..spec.time_build.x_exit_time.ncols()])
.to_owned();
fitted_weibull_baseline_from_linear_time_beta(&time_beta, spec.time_anchor).ok_or_else(|| {
"failed to recover fitted Weibull scale/shape from the cause-specific linear time coefficients"
.to_string()
})?
} else {
baseline_cfg
};
Ok(SurvivalTransformationFitResult {
fit,
resolvedspec,
baseline_cfg: fitted_baseline_cfg,
likelihood_mode: spec.likelihood_mode,
time_basis,
time_base_ncols: spec.time_build.x_exit_time.ncols(),
baseline_timewiggle: prepared.timewiggle_block,
})
}
/// Builds the rho prior for the cause-specific survival fit from labelled Gamma hyperpriors.
///
/// Each entry of `penalty_block_gamma_priors` is `(label, shape, rate)`. Labels must match
/// `cause_specific_survival_cause_{cause}_penalty_{index}` (1-based cause, 0-based penalty
/// index). An empty list yields `RhoPrior::Flat`; otherwise the result is
/// `RhoPrior::Independent` with one entry per (cause, penalty) pair in cause-major order,
/// `Flat` wherever no hyperprior was supplied.
///
/// # Errors
/// Fails on a duplicated label, a non-finite or non-positive shape, a non-finite or negative
/// rate, or any label that does not correspond to a (cause, penalty) pair.
fn cause_specific_survival_rho_prior(
    cause_count: usize,
    penalty_count: usize,
    penalty_block_gamma_priors: &[(String, f64, f64)],
) -> Result<crate::types::RhoPrior, String> {
    /// Canonical precision label for one (cause, penalty) pair.
    fn block_label(cause: usize, penalty_idx: usize) -> String {
        format!(
            "cause_specific_survival_cause_{}_penalty_{penalty_idx}",
            cause + 1
        )
    }

    if penalty_block_gamma_priors.is_empty() {
        return Ok(crate::types::RhoPrior::Flat);
    }

    // Validate entries in input order: duplicate label first, then shape, then rate.
    let mut pending = BTreeMap::<String, (f64, f64)>::new();
    for entry in penalty_block_gamma_priors {
        let (label, shape, rate) = (&entry.0, entry.1, entry.2);
        let problem = if pending.insert(label.clone(), (shape, rate)).is_some() {
            Some(format!(
                "duplicate Gamma precision hyperprior for penalty block label '{label}'"
            ))
        } else if !(shape.is_finite() && shape > 0.0) {
            Some(format!(
                "Gamma precision hyperprior for penalty block '{label}' requires shape > 0, got {shape}"
            ))
        } else if !(rate.is_finite() && rate >= 0.0) {
            Some(format!(
                "Gamma precision hyperprior for penalty block '{label}' requires rate >= 0, got {rate}"
            ))
        } else {
            None
        };
        if let Some(reason) = problem {
            return Err(WorkflowError::InvalidConfig { reason }.into());
        }
    }

    // Every generated label is unique, so taking a hyperprior out of the map as it is used
    // leaves exactly the labels that matched no (cause, penalty) pair.
    let mut priors = Vec::<crate::types::RhoPrior>::with_capacity(cause_count * penalty_count);
    for cause in 0..cause_count {
        for penalty_idx in 0..penalty_count {
            let prior = match pending.remove(&block_label(cause, penalty_idx)) {
                Some((shape, rate)) => crate::types::RhoPrior::GammaPrecision { shape, rate },
                None => crate::types::RhoPrior::Flat,
            };
            priors.push(prior);
        }
    }

    if !pending.is_empty() {
        let unknown: Vec<String> = pending.keys().cloned().collect();
        let available = (0..cause_count)
            .flat_map(|cause| (0..penalty_count).map(move |idx| block_label(cause, idx)))
            .collect::<Vec<_>>()
            .join(", ");
        return Err(WorkflowError::InvalidConfig {
            reason: format!(
                "unknown Gamma precision hyperprior penalty block label(s): {}; available labels: {available}",
                unknown.join(", ")
            ),
        }
        .into());
    }

    Ok(crate::types::RhoPrior::Independent(priors))
}
/// Feeds a 1-D `f64` view into the fingerprint: the length first, then every element in order.
fn hash_workflow_array_view(
    hasher: &mut crate::warm_start::Fingerprinter,
    array: ArrayView1<'_, f64>,
) {
    hasher.write_usize(array.len());
    array.iter().for_each(|&entry| hasher.write_f64(entry));
}
/// Feeds a 1-D `u8` view into the fingerprint: the length first, then every element
/// (widened to `usize`) in order.
fn hash_workflow_u8_array(
    hasher: &mut crate::warm_start::Fingerprinter,
    array: ArrayView1<'_, u8>,
) {
    hasher.write_usize(array.len());
    array
        .iter()
        .map(|&code| usize::from(code))
        .for_each(|widened| hasher.write_usize(widened));
}
/// Feeds a 2-D `f64` view into the fingerprint: row count, column count, then every element
/// visited row by row.
fn hash_workflow_array2(hasher: &mut crate::warm_start::Fingerprinter, array: ArrayView2<'_, f64>) {
    let (n_rows, n_cols) = (array.nrows(), array.ncols());
    hasher.write_usize(n_rows);
    hasher.write_usize(n_cols);
    for row in array.rows() {
        row.iter().for_each(|&entry| hasher.write_f64(entry));
    }
}
/// Fingerprints a design matrix by densifying it and hashing the dense values.
fn hash_workflow_design_matrix(
    hasher: &mut crate::warm_start::Fingerprinter,
    matrix: &crate::matrix::DesignMatrix,
) {
    hash_workflow_array2(hasher, matrix.to_dense().view());
}
/// Returns `ln(lambda)` for every penalty block, in block order.
///
/// Each lambda is first raised to at least `LOG_LAMBDA_UNDERFLOW_FLOOR`, so a zero lambda
/// maps to a finite log value.
fn survival_transformation_log_lambdas(
    penalty_blocks: &[crate::families::survival::PenaltyBlock],
) -> Vec<f64> {
    let mut log_lambdas = Vec::with_capacity(penalty_blocks.len());
    for block in penalty_blocks {
        let floored = block.lambda.max(LOG_LAMBDA_UNDERFLOW_FLOOR);
        log_lambdas.push(floored.ln());
    }
    log_lambdas
}
/// Builds the persistent warm-start cache key for a survival transformation PIRLS fit.
///
/// The key is the string `surv-transform-` followed by the hex digest of a fingerprint over:
/// a fixed domain tag, the cache schema tag, the likelihood mode, time anchor and ridge
/// lambda, the baseline configuration, the time-basis description, `n_cols`, the
/// observation-level data, the prepared offsets and time designs, the penalty blocks, and
/// selected PIRLS options.
///
/// NOTE(review): the digest presumably depends on the order of the writes below, so
/// reordering or adding writes would change every key — confirm against `Fingerprinter`.
fn persistent_survival_transformation_key(
spec: &SurvivalTransformationTermSpec,
baseline_cfg: &crate::families::survival::construction::SurvivalBaselineConfig,
dense_cov_design: ArrayView2<'_, f64>,
prepared: &PreparedSurvivalTimeStack,
penalty_blocks: &[crate::families::survival::PenaltyBlock],
opts: &crate::pirls::WorkingModelPirlsOptions,
n_cols: usize,
) -> String {
let mut hasher = crate::warm_start::Fingerprinter::new();
// Domain tag and schema tag come first.
hasher.write_str("gamfit-persistent-survival-transformation-working-pirls");
hasher.write_str(&crate::solver::persistent_warm_start::cache_schema_tag());
// Enum-valued settings are hashed through their `Debug` rendering.
hasher.write_str(&format!("{:?}", spec.likelihood_mode));
hasher.write_f64(spec.time_anchor);
hasher.write_f64(spec.ridge_lambda);
hasher.write_str(&format!("{:?}", baseline_cfg.target));
// Optional baseline parameters: a presence flag, then the value when present.
for value in [
baseline_cfg.scale,
baseline_cfg.shape,
baseline_cfg.rate,
baseline_cfg.makeham,
] {
hasher.write_bool(value.is_some());
if let Some(value) = value {
hasher.write_f64(value);
}
}
// Time-basis description: name, matrix shapes, degree, knots, kept columns, smooth lambda.
// Only the SHAPES of the `time_build` matrices are hashed here; the values of the prepared
// time designs are hashed further below.
hasher.write_str(&spec.time_build.basisname);
hasher.write_usize(spec.time_build.x_entry_time.nrows());
hasher.write_usize(spec.time_build.x_entry_time.ncols());
hasher.write_usize(spec.time_build.x_exit_time.nrows());
hasher.write_usize(spec.time_build.x_exit_time.ncols());
hasher.write_usize(spec.time_build.x_derivative_time.nrows());
hasher.write_usize(spec.time_build.x_derivative_time.ncols());
hasher.write_bool(spec.time_build.degree.is_some());
if let Some(degree) = spec.time_build.degree {
hasher.write_usize(degree);
}
match spec.time_build.knots.as_ref() {
Some(knots) => {
hasher.write_bool(true);
hasher.write_usize(knots.len());
for &knot in knots {
hasher.write_f64(knot);
}
}
None => hasher.write_bool(false),
}
match spec.time_build.keep_cols.as_ref() {
Some(cols) => {
hasher.write_bool(true);
hasher.write_usize(cols.len());
for &col in cols {
hasher.write_usize(col);
}
}
None => hasher.write_bool(false),
}
hasher.write_bool(spec.time_build.smooth_lambda.is_some());
if let Some(lambda) = spec.time_build.smooth_lambda {
hasher.write_f64(lambda);
}
hasher.write_usize(n_cols);
// Observation-level data, followed by the prepared offsets and dense time designs.
hash_workflow_array_view(&mut hasher, spec.age_entry.view());
hash_workflow_array_view(&mut hasher, spec.age_exit.view());
hash_workflow_u8_array(&mut hasher, spec.event_target.view());
hash_workflow_array_view(&mut hasher, spec.weights.view());
hash_workflow_array_view(&mut hasher, spec.covariate_offset.view());
hash_workflow_array2(&mut hasher, dense_cov_design);
hash_workflow_array_view(&mut hasher, prepared.eta_offset_entry.view());
hash_workflow_array_view(&mut hasher, prepared.eta_offset_exit.view());
hash_workflow_array_view(&mut hasher, prepared.derivative_offset_exit.view());
hash_workflow_design_matrix(&mut hasher, &prepared.time_design_entry);
hash_workflow_design_matrix(&mut hasher, &prepared.time_design_exit);
hash_workflow_design_matrix(&mut hasher, &prepared.time_design_derivative_exit);
// Penalty blocks: count, then lambda, column range, nullspace dimension and matrix of each.
hasher.write_usize(penalty_blocks.len());
for block in penalty_blocks {
hasher.write_f64(block.lambda);
hasher.write_usize(block.range.start);
hasher.write_usize(block.range.end);
hasher.write_usize(block.nullspace_dim);
hash_workflow_array2(&mut hasher, block.matrix.view());
}
// PIRLS options that enter the key.
hasher.write_usize(opts.max_iterations);
hasher.write_f64(opts.convergence_tolerance);
hasher.write_usize(opts.max_step_halving);
hasher.write_f64(opts.min_step_size);
hasher.write_bool(opts.firth_bias_reduction);
hasher.write_bool(opts.coefficient_lower_bounds.is_some());
if let Some(bounds) = opts.coefficient_lower_bounds.as_ref() {
hash_workflow_array_view(&mut hasher, bounds.view());
}
// Only the PRESENCE of linear constraints is hashed, not their contents.
hasher.write_bool(opts.linear_constraints.is_some());
format!("surv-transform-{}", hasher.finish_hex())
}
/// Looks up a cached warm start for `key` and returns it when it matches the current fit.
///
/// A record is accepted only if it is compatible with `key`, the observation count
/// (`spec.age_entry.len()`) and `n_cols`, and its stored rho vector has the same length as
/// `rho` with every entry within `1e-10` of the expected value.
///
/// Returns the cached coefficients together with the cached LM lambda; the lambda is dropped
/// (`None`) unless it is finite and strictly positive.
fn load_survival_transformation_persistent_warm_start(
    key: &str,
    spec: &SurvivalTransformationTermSpec,
    n_cols: usize,
    rho: &[f64],
) -> Option<(Array1<f64>, Option<f64>)> {
    let record = crate::solver::persistent_warm_start::load_record(key)?;

    if !record.is_compatible(key, spec.age_entry.len(), n_cols) {
        return None;
    }
    if record.rho.len() != rho.len() {
        return None;
    }
    let rho_matches = record
        .rho
        .iter()
        .zip(rho.iter())
        .all(|(cached, expected)| (*cached - *expected).abs() <= 1e-10);
    if !rho_matches {
        return None;
    }

    log::info!("[warm-start-cache] restored survival transformation warm start key={key}");
    let lm_lambda = match record.last_pirls_lm_lambda {
        Some(value) if value.is_finite() && value > 0.0 => Some(value),
        _ => None,
    };
    Some((Array1::from_vec(record.beta), lm_lambda))
}
/// Persists the fitted coefficients as a warm start for later runs with the same `key`.
///
/// Nothing is stored when `beta` does not have `n_cols` entries or when `beta` or `rho`
/// contains a non-finite value. A failure to write the record is logged as a warning and
/// otherwise ignored.
fn store_survival_transformation_persistent_warm_start(
    key: &str,
    spec: &SurvivalTransformationTermSpec,
    n_cols: usize,
    rho: Vec<f64>,
    beta: &Array1<f64>,
    summary: &crate::pirls::WorkingModelPirlsResult,
) {
    let all_finite =
        beta.iter().all(|value| value.is_finite()) && rho.iter().all(|value| value.is_finite());
    if beta.len() != n_cols || !all_finite {
        return;
    }

    let converged = matches!(
        summary.status,
        crate::pirls::PirlsStatus::Converged | crate::pirls::PirlsStatus::StalledAtValidMinimum
    );
    // Only a finite, strictly positive LM lambda is worth remembering.
    let lm_lambda = summary.final_lm_lambda;
    let cached_lm_lambda = if lm_lambda.is_finite() && lm_lambda > 0.0 {
        Some(lm_lambda)
    } else {
        None
    };
    let cached_accept_rho = summary
        .final_accept_rho
        .filter(|value| value.is_finite() && *value >= 0.0);

    let mut record = crate::solver::persistent_warm_start::PersistentWarmStartRecord::new(
        key.to_string(),
        spec.age_entry.len(),
        n_cols,
    );
    record.rho = rho;
    record.beta = beta.to_vec();
    record.last_inner_iters = summary.iterations;
    record.last_inner_converged = converged;
    record.last_pirls_lm_lambda = cached_lm_lambda;
    record.last_pirls_accept_rho = cached_accept_rho;

    match crate::solver::persistent_warm_start::store_record(&record) {
        Ok(_) => {}
        Err(err) => {
            log::warn!(
                "[warm-start-cache] failed to persist survival transformation warm start: {err}"
            );
        }
    }
}
/// Fits a survival transformation model from a workflow request.
///
/// Steps, in order:
/// 1. Build and freeze the covariate design.
/// 2. If the baseline target is not `Linear`, optimize the baseline configuration using the
///    working-model score and its chain-rule gradient.
/// 3. Build the working model at the chosen baseline.
/// 4. If there is more than one cause, or Gamma hyperpriors were supplied, delegate to
///    `fit_cause_specific_survival_transformation_custom` and return its result.
/// 5. Otherwise optionally select smoothing parameters, run PIRLS (seeded from the persistent
///    warm-start cache when a matching record exists), store the new warm start, and assemble
///    the unified fit result.
///
/// # Errors
/// Returns a string error when design construction, model construction, baseline
/// optimization, smoothing selection or PIRLS fails, when PIRLS ends in a non-converged
/// status with non-finite results, or when Weibull baseline recovery fails.
pub(crate) fn fit_survival_transformation_model(
request: SurvivalTransformationFitRequest<'_>,
) -> Result<SurvivalTransformationFitResult, String> {
use crate::families::survival::{
PenaltyBlock, PenaltyBlocks, SurvivalMonotonicityPenalty, SurvivalSpec,
};
// `cache_session` is bound but not used anywhere in this function.
let SurvivalTransformationFitRequest {
data,
spec,
cache_session: _cache_session,
} = request;
let mut baseline_cfg = spec.baseline_cfg.clone();
let covariate_design =
build_term_collection_design(data, &spec.covariate_spec).map_err(|err| err.to_string())?;
let resolvedspec =
crate::smooth::freeze_term_collection_from_design(&spec.covariate_spec, &covariate_design)
.map_err(|err| err.to_string())?;
let dense_cov_design = covariate_design.design.to_dense();
let p_cov = dense_cov_design.ncols();
let cause_count =
crate::families::survival::cause_count_from_event_codes(spec.event_target.view())
.into_workflow_result()?;
let exact_derivative_guard = survival_derivative_guard_for_likelihood(spec.likelihood_mode);
// Builds everything needed to fit at one candidate baseline configuration: the prepared
// time stack, the penalty blocks, the starting coefficients, optional coefficient lower
// bounds, the working model, and the number of smoothing (non-ridge) penalty blocks.
let build_working_model =
|candidate: &crate::families::survival::construction::SurvivalBaselineConfig| {
let prepared = prepare_survival_time_stack(
&spec.age_entry,
&spec.age_exit,
candidate,
spec.likelihood_mode,
None,
spec.time_anchor,
exact_derivative_guard,
&spec.time_build,
spec.timewiggle.as_ref(),
None,
)?;
// The covariate offset is added to both the entry and exit linear-predictor offsets.
let mut eta_offset_entry = prepared.eta_offset_entry.clone();
let mut eta_offset_exit = prepared.eta_offset_exit.clone();
eta_offset_entry += &spec.covariate_offset;
eta_offset_exit += &spec.covariate_offset;
let p_time_total = prepared.time_design_exit.ncols();
let p = p_time_total + p_cov;
let mut penalty_blocks = Vec::<PenaltyBlock>::new();
// Time penalties: only those whose shape matches the full time block are kept; the
// starting lambda is the configured smooth lambda, or 1e-2 when none is set.
for (idx, penalty) in prepared.time_penalties.iter().enumerate() {
if penalty.nrows() == p_time_total && penalty.ncols() == p_time_total {
penalty_blocks.push(PenaltyBlock {
matrix: penalty.clone(),
lambda: spec.time_build.smooth_lambda.unwrap_or(1e-2),
range: 0..p_time_total,
nullspace_dim: prepared.time_nullspace_dims.get(idx).copied().unwrap_or(0),
});
}
}
// Covariate penalties: kept only when non-empty, shape-consistent, zero-mean and within
// the covariate columns; their ranges are shifted past the time columns.
for (penalty_idx, cov_penalty) in covariate_design.penalties.iter().enumerate() {
let cr = &cov_penalty.col_range;
let block_dim = cr.end - cr.start;
let matches_dims = cov_penalty.local.nrows() == block_dim
&& cov_penalty.local.ncols() == block_dim;
let zero_prior = matches!(
cov_penalty.prior_mean,
crate::estimate::CoefficientPriorMean::Zero
);
if block_dim > 0 && matches_dims && zero_prior && cr.end <= p_cov {
penalty_blocks.push(PenaltyBlock {
matrix: cov_penalty.local.clone(),
lambda: 1e-2,
range: (p_time_total + cr.start)..(p_time_total + cr.end),
nullspace_dim: covariate_design
.nullspace_dims
.get(penalty_idx)
.copied()
.unwrap_or(0),
});
}
}
// Counted before the ridge block is appended, so the ridge is not a smoothing block.
let num_smoothing_blocks = penalty_blocks.len();
// For a Weibull likelihood on the "linear" basis without a time wiggle, the first
// coefficient is left out of the ridge penalty.
let ridge_range_start = if spec.likelihood_mode == SurvivalLikelihoodMode::Weibull
&& spec.time_build.basisname == "linear"
&& spec.timewiggle.is_none()
{
1
} else {
0
};
// Optional identity ridge over the remaining coefficients.
if spec.ridge_lambda > 0.0 && p > ridge_range_start {
let dim = p - ridge_range_start;
let mut ridge = Array2::<f64>::zeros((dim, dim));
for d in 0..dim {
ridge[[d, d]] = 1.0;
}
penalty_blocks.push(PenaltyBlock {
matrix: ridge,
lambda: spec.ridge_lambda,
range: ridge_range_start..p,
nullspace_dim: 0,
});
}
let dense_time_entry = prepared.time_design_entry.to_dense();
let dense_time_exit = prepared.time_design_exit.to_dense();
let dense_time_derivative = prepared.time_design_derivative_exit.to_dense();
// The pooled working model has no competing events and treats any positive event code
// as an event.
let event_competing = Array1::<u8>::zeros(spec.event_target.len());
let baseline_event_indicator = spec.event_target.mapv(|label| u8::from(label > 0));
let mut model =
crate::families::survival::royston_parmar::working_model_from_time_covariateshared(
PenaltyBlocks::new(penalty_blocks.clone()),
SurvivalMonotonicityPenalty { tolerance: 0.0 },
SurvivalSpec::Net,
crate::families::survival::royston_parmar::RoystonParmarSharedTimeCovariateInputs {
age_entry: spec.age_entry.view(),
age_exit: spec.age_exit.view(),
event_target: baseline_event_indicator.view(),
event_competing: event_competing.view(),
weights: spec.weights.view(),
time_entry: dense_time_entry.view(),
time_exit: dense_time_exit.view(),
time_derivative: dense_time_derivative.view(),
covariates: dense_cov_design.view(),
monotonicity_constraint_rows: None,
monotonicity_constraint_offsets: None,
eta_offset_entry: Some(eta_offset_entry.view()),
eta_offset_exit: Some(eta_offset_exit.view()),
derivative_offset_exit: Some(prepared.derivative_offset_exit.view()),
},
)
.map_err(|err| format!("failed to construct survival model: {err}"))?;
// Structural monotonicity is enabled for every likelihood mode except Weibull.
if spec.likelihood_mode != SurvivalLikelihoodMode::Weibull {
model
.set_structural_monotonicity(true, p_time_total)
.map_err(|err| format!("failed to enable structural monotonicity: {err}"))?;
}
let mut beta0 = Array1::<f64>::zeros(p);
// Weibull without a time wiggle: seed the first two time coefficients from the supplied
// (scale, shape) pair as (-shape * ln(scale), shape).
if spec.likelihood_mode == SurvivalLikelihoodMode::Weibull && spec.timewiggle.is_none()
{
let (scale, shape) = spec
.weibull_seed
.ok_or_else(|| "weibull survival fit missing scale/shape seed".to_string())?;
if p_time_total < 2 {
return Err(format!(
"weibull built-in time basis has {p_time_total} columns but needs 2 to seed scale/shape"
));
}
beta0[0] = -shape * scale.ln();
beta0[1] = shape;
}
// Non-Weibull modes: time coefficients get a lower bound of 0 and start at 1e-4;
// covariate coefficients stay unbounded below.
let structural_lower_bounds =
if spec.likelihood_mode != SurvivalLikelihoodMode::Weibull && p_time_total > 0 {
let mut lb = Array1::from_elem(p, f64::NEG_INFINITY);
for j in 0..p_time_total {
lb[j] = 0.0;
beta0[j] = 1e-4;
}
Some(lb)
} else {
None
};
Ok::<_, String>((
prepared,
penalty_blocks,
beta0,
structural_lower_bounds,
model,
num_smoothing_blocks,
))
};
// Baseline optimization: each candidate is scored by rebuilding the working model, running
// PIRLS from the default start, and evaluating the score and its gradient at the optimum.
if baseline_cfg.target != SurvivalBaselineTarget::Linear {
baseline_cfg = optimize_survival_baseline_config_with_gradient_only(
&baseline_cfg,
"workflow survival transformation baseline",
|candidate| {
let (_, _, beta0, structural_lower_bounds, mut model, _) =
build_working_model(candidate)?;
let opts = crate::pirls::WorkingModelPirlsOptions {
max_iterations: SURVIVAL_TRANSFORMATION_PIRLS_MAX_ITERATIONS,
convergence_tolerance: SURVIVAL_TRANSFORMATION_PIRLS_CONVERGENCE_TOL,
adaptive_kkt_tolerance: None,
max_step_halving: SURVIVAL_TRANSFORMATION_PIRLS_MAX_STEP_HALVING,
min_step_size: SURVIVAL_TRANSFORMATION_PIRLS_MIN_STEP_SIZE,
firth_bias_reduction: false,
coefficient_lower_bounds: structural_lower_bounds,
linear_constraints: None,
initial_lm_lambda: None,
geodesic_acceleration: false,
arrow_schur: None,
};
let summary = crate::pirls::runworking_model_pirls(
&mut model,
crate::types::Coefficients::new(beta0),
&opts,
|_| {},
)
.map_err(|err| format!("survival PIRLS failed: {err}"))?;
let beta = summary.beta.as_ref().to_owned();
let state = model.update_state(&beta).map_err(|err| {
format!("failed to evaluate survival baseline candidate: {err}")
})?;
let cost = survival_working_reml_score(&state);
let residuals = model.offset_channel_residuals(&beta).map_err(|err| {
format!("failed to form survival baseline offset residuals: {err}")
})?;
// NOTE(review): `spec.age_exit` is passed for both the second and third argument —
// confirm against the signature of `baseline_chain_rule_gradient` that this is intended.
let gradient = baseline_chain_rule_gradient(
spec.age_entry.view(),
spec.age_exit.view(),
spec.age_exit.view(),
candidate,
&residuals,
)?
.ok_or_else(|| {
"workflow survival transformation baseline unexpectedly has no theta gradient"
.to_string()
})?;
Ok((cost, gradient))
},
)?;
}
// Rebuild the working model at the final baseline configuration.
let (
prepared,
mut penalty_blocks,
beta0,
structural_lower_bounds,
mut model,
num_smoothing_blocks,
) = build_working_model(&baseline_cfg)?;
// Multiple causes, or any Gamma hyperprior, route the fit through the custom-family path,
// seeded with the pooled starting vector replicated once per cause.
if cause_count > 1 || !spec.penalty_block_gamma_priors.is_empty() {
let beta0_flat = replicate_pooled_baseline_seed_per_cause(beta0.view(), cause_count);
return fit_cause_specific_survival_transformation_custom(
&spec,
resolvedspec,
baseline_cfg,
prepared,
&dense_cov_design,
penalty_blocks,
beta0_flat,
exact_derivative_guard,
&spec.penalty_block_gamma_priors,
);
}
// When the smoothing selector returns lambdas, push them into both the model and the
// local penalty-block list so the two stay in agreement.
if let Some(selected_lambdas) = optimize_survival_transformation_smoothing(
&model,
&penalty_blocks,
num_smoothing_blocks,
&beta0,
structural_lower_bounds.as_ref(),
)? {
model
.set_penalty_lambdas(&selected_lambdas)
.map_err(|e| e.to_string())?;
for (block, &lam) in penalty_blocks.iter_mut().zip(selected_lambdas.iter()) {
block.lambda = lam;
}
}
let opts = crate::pirls::WorkingModelPirlsOptions {
max_iterations: SURVIVAL_TRANSFORMATION_PIRLS_MAX_ITERATIONS,
convergence_tolerance: SURVIVAL_TRANSFORMATION_PIRLS_CONVERGENCE_TOL,
adaptive_kkt_tolerance: None,
max_step_halving: SURVIVAL_TRANSFORMATION_PIRLS_MAX_STEP_HALVING,
min_step_size: SURVIVAL_TRANSFORMATION_PIRLS_MIN_STEP_SIZE,
firth_bias_reduction: false,
coefficient_lower_bounds: structural_lower_bounds,
linear_constraints: None,
initial_lm_lambda: None,
geodesic_acceleration: false,
arrow_schur: None,
};
// The cache key is computed from the options BEFORE any cached LM lambda is applied below.
let rho_for_cache = survival_transformation_log_lambdas(&penalty_blocks);
let persistent_warm_start_key = persistent_survival_transformation_key(
&spec,
&baseline_cfg,
dense_cov_design.view(),
&prepared,
&penalty_blocks,
&opts,
beta0.len(),
);
let mut opts = opts;
// Start from the cached coefficients (and cached LM lambda) when a matching record exists;
// otherwise start from the default seed.
let beta_start = match load_survival_transformation_persistent_warm_start(
&persistent_warm_start_key,
&spec,
beta0.len(),
&rho_for_cache,
) {
Some((beta, lm_lambda)) => {
opts.initial_lm_lambda = lm_lambda;
beta
}
None => beta0,
};
let summary = crate::pirls::runworking_model_pirls(
&mut model,
crate::types::Coefficients::new(beta_start),
&opts,
|_| {},
)
.map_err(|err| format!("survival PIRLS failed: {err}"))?;
// A non-converged status is tolerated (with a warning) as long as the coefficients,
// deviance and gradient norm are all finite; otherwise the fit is rejected.
match summary.status {
crate::pirls::PirlsStatus::Converged | crate::pirls::PirlsStatus::StalledAtValidMinimum => {
}
ref other => {
let beta_finite = summary.beta.as_ref().iter().all(|v| v.is_finite());
let result_finite = beta_finite
&& summary.state.deviance.is_finite()
&& summary.lastgradient_norm.is_finite();
if result_finite {
log::warn!(
"[#1123] survival transformation inner PIRLS at the selected λ did not reach the \
convergence tolerance (status={other:?}, grad_norm={:.3e}, iterations={}, \
deviance={:.6e}), but landed at a finite optimum; accepting it rather than \
aborting the fit (the outer smoothing selector already steers away from \
un-fittable λ, and a finite optimum is a usable model).",
summary.lastgradient_norm,
summary.iterations,
summary.state.deviance,
);
} else {
return Err(WorkflowError::IntegrationFailed {
reason: format!(
"survival PIRLS did not converge to a finite optimum: status={other:?}, grad_norm={:.3e}, iterations={}, deviance={:.6e}",
summary.lastgradient_norm, summary.iterations, summary.state.deviance
),
}
.into());
}
}
}
let beta = summary.beta.as_ref().to_owned();
// The warm start is stored for every accepted result, including a tolerated
// non-converged one; the record carries the convergence flag.
store_survival_transformation_persistent_warm_start(
&persistent_warm_start_key,
&spec,
beta.len(),
rho_for_cache,
&beta,
&summary,
);
let state = model
.update_state(&beta)
.map_err(|err| format!("failed to evaluate survival optimum: {err}"))?;
let lambdas = Array1::from_iter(penalty_blocks.iter().map(|block| block.lambda));
// For a Weibull likelihood without a time wiggle, the reported baseline is recovered from
// the fitted leading time coefficients; otherwise the optimized configuration is returned.
let fitted_baseline_cfg =
if spec.likelihood_mode == SurvivalLikelihoodMode::Weibull && spec.timewiggle.is_none() {
let time_beta = beta
.slice(s![..spec.time_build.x_exit_time.ncols()])
.to_owned();
fitted_weibull_baseline_from_linear_time_beta(&time_beta, spec.time_anchor).ok_or_else(
|| {
"failed to recover fitted Weibull scale/shape from the linear time coefficients"
.to_string()
},
)?
} else {
baseline_cfg
};
let fit = survival_unified_fit_result(beta, lambdas, &summary, &state, &penalty_blocks)?;
let time_base_ncols = spec.time_build.x_exit_time.ncols();
let time_basis = crate::families::survival::construction::SavedSurvivalTimeBasis::from_build(
&spec.time_build,
spec.time_anchor,
);
Ok(SurvivalTransformationFitResult {
fit,
resolvedspec,
baseline_cfg: fitted_baseline_cfg,
likelihood_mode: spec.likelihood_mode,
time_basis,
time_base_ncols,
baseline_timewiggle: prepared.timewiggle_block,
})
}
/// Fits a survival location-scale model from a workflow request.
///
/// When `request.optimize_inverse_link` is set, the free parameters of a SAS, BetaLogistic or
/// non-empty-rho Mixture inverse link are optimized by repeatedly profiling the model; any
/// other link is profiled once as given. Without the flag the model is profiled once at the
/// link carried by the spec.
pub(crate) fn fit_survival_location_scale_model(
request: SurvivalLocationScaleFitRequest<'_>,
) -> Result<SurvivalLocationScaleFitResult, String> {
/// Fits the model at the inverse link carried by `spec`.
///
/// With a wiggle configuration, a pilot fit is run without the link-wiggle block, a wiggle
/// basis is selected from that pilot, and the model is refitted with the selected basis;
/// the selected knots and degree are reported in the returned profile.
fn profile_survival_location_scale(
data: ArrayView2<'_, f64>,
spec: SurvivalLocationScaleTermSpec,
wiggle: Option<LinkWiggleConfig>,
kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<SurvivalLocationScaleProfile, String> {
let mut wiggle_knots = None;
let mut wiggle_degree = None;
let inverse_link = spec.inverse_link.clone();
let fit = if let Some(wiggle) = wiggle {
require_inverse_link_supports_joint_wiggle(&inverse_link, "survival link wiggle")?;
// Pilot fit: same spec with the link-wiggle block removed.
let mut pilot_spec = spec.clone();
pilot_spec.linkwiggle_block = None;
let pilot = fit_survival_location_scale_terms(data, pilot_spec, kappa_options)?;
let selected_wiggle_basis = select_survival_link_wiggle_basis_from_pilot(
&pilot,
&WiggleBlockConfig {
degree: wiggle.degree,
num_internal_knots: wiggle.num_internal_knots,
penalty_order: 2,
double_penalty: wiggle.double_penalty,
},
&wiggle.penalty_orders,
)?;
wiggle_knots = Some(selected_wiggle_basis.knots.clone());
wiggle_degree = Some(selected_wiggle_basis.degree);
fit_survival_location_scale_terms_with_selected_wiggle(
data,
spec,
selected_wiggle_basis,
kappa_options,
)?
} else {
fit_survival_location_scale_terms(data, spec, kappa_options)?
};
Ok(SurvivalLocationScaleProfile {
fit,
inverse_link,
wiggle_knots,
wiggle_degree,
})
}
/// Clones `spec`, swaps in `inverse_link`, and profiles the model at that link.
fn profile_survival_location_scale_with_inverse_link(
data: ArrayView2<'_, f64>,
spec: &SurvivalLocationScaleTermSpec,
inverse_link: InverseLink,
wiggle: Option<LinkWiggleConfig>,
kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<SurvivalLocationScaleProfile, String> {
let mut spec_at_link = spec.clone();
spec_at_link.inverse_link = inverse_link;
profile_survival_location_scale(data, spec_at_link, wiggle, kappa_options)
}
/// Dispatches on the spec's inverse link: links with free parameters are optimized, every
/// other link is profiled once unchanged.
fn optimize_survival_inverse_link_profile(
data: ArrayView2<'_, f64>,
spec: &SurvivalLocationScaleTermSpec,
wiggle: Option<LinkWiggleConfig>,
kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<SurvivalLocationScaleProfile, String> {
/// Optimizes the link parameter vector starting from `init`.
///
/// `make_link` turns a parameter vector into a link for each objective evaluation
/// (profiled with `wiggle_cfg`); `recover` rebuilds the link from the optimizer's final
/// parameters, after which the model is profiled once more with `final_wiggle`.
fn optimize_link_parameters<R>(
data: ArrayView2<'_, f64>,
spec: &SurvivalLocationScaleTermSpec,
kappa_options: &SpatialLengthScaleOptimizationOptions,
init: Array1<f64>,
name: &str,
final_wiggle: Option<LinkWiggleConfig>,
wiggle_cfg: Option<LinkWiggleConfig>,
make_link: impl Fn(&Array1<f64>) -> Result<InverseLink, String> + Clone,
recover: R,
) -> Result<SurvivalLocationScaleProfile, String>
where
R: Fn(&Array1<f64>) -> Option<InverseLink>,
{
use crate::solver::rho_optimizer::{
DeclaredHessianForm, Derivative, HessianResult, OuterEval, OuterProblem,
};
let dim = init.len();
// Box constraints: each parameter may move at most 6 units from its starting value.
let lower = init.mapv(|v| v - 6.0);
let upper = init.mapv(|v| v + 6.0);
// Analytic gradient, no Hessian, a single seed at `init`.
let problem = OuterProblem::new(dim)
.with_gradient(Derivative::Analytic)
.with_hessian(DeclaredHessianForm::Unavailable)
.with_tolerance(1e-4)
.with_max_iter(240)
.with_bounds(lower, upper)
.with_initial_rho(init.clone())
.with_seed_config(crate::seeding::SeedConfig {
max_seeds: 1,
seed_budget: 1,
num_auxiliary_trailing: dim,
..Default::default()
});
let context = format!("survival inverse-link optimization ({name}, dim={dim})");
// One objective evaluation: build the link, profile the model, and return
// (cost, gradient) where cost = -log_likelihood + 0.5 * stable_penalty_term.
// Every evaluation runs a full profile fit, including the pilot/wiggle selection
// when `wiggle_cfg` is set.
let eval_link = move |theta: &Array1<f64>| -> Result<(f64, Array1<f64>), String> {
let link = make_link(theta)?;
let profile = profile_survival_location_scale_with_inverse_link(
data,
spec,
link,
wiggle_cfg.clone(),
kappa_options,
)?;
let cost =
-profile.fit.fit.log_likelihood + 0.5 * profile.fit.fit.stable_penalty_term;
if !cost.is_finite() {
return Err(format!(
"survival inverse-link ({name}): non-finite profile cost \
(log_likelihood={}, stable_penalty_term={})",
profile.fit.fit.log_likelihood, profile.fit.fit.stable_penalty_term
));
}
// The gradient is taken from the fit's link-parameter data-fit gradient and must
// have one entry per link parameter.
let gradient = profile
.fit
.link_param_data_fit_gradient
.clone()
.ok_or_else(|| {
format!(
"survival inverse-link ({name}): fit reported no link-parameter \
data-fit gradient"
)
})?;
if gradient.len() != theta.len() {
return Err(format!(
"survival inverse-link ({name}): gradient dim {} != theta dim {}",
gradient.len(),
theta.len()
));
}
Ok((cost, gradient))
};
// The evaluator is cloned so the cost-only and cost+gradient callbacks can each own one.
let cost_eval = eval_link.clone();
let cost_fn = move |_: &mut (), theta: &Array1<f64>| {
cost_eval(theta)
.map(|(cost, _)| cost)
.map_err(crate::estimate::EstimationError::InvalidInput)
};
let eval_fn = move |_: &mut (), theta: &Array1<f64>| {
let (cost, gradient) =
eval_link(theta).map_err(crate::estimate::EstimationError::InvalidInput)?;
Ok(OuterEval {
cost,
gradient,
hessian: HessianResult::Unavailable,
inner_beta_hint: None,
})
};
// The two trailing `None`s are optional callbacks that are not supplied; the explicit
// function-pointer types only serve type inference.
let mut obj = problem.build_objective(
(),
cost_fn,
eval_fn,
None::<fn(&mut ())>,
None::<
fn(
&mut (),
&Array1<f64>,
) -> Result<
crate::solver::rho_optimizer::EfsEval,
crate::estimate::EstimationError,
>,
>,
);
let result = problem
.run(&mut obj, &context)
.map_err(|err| format!("{context} failed: {err}"))?;
// A non-converged optimizer result is rejected here instead of being profiled.
let link = recover_converged_survival_inverse_link(result, &context, recover)?;
profile_survival_location_scale_with_inverse_link(
data,
spec,
link,
final_wiggle,
kappa_options,
)
.map_err(|err| format!("{context} final profiling failed: {err}"))
}
match spec.inverse_link.clone() {
// SAS: two parameters, (epsilon, log_delta).
InverseLink::Sas(state0) => optimize_link_parameters(
data,
spec,
kappa_options,
Array1::from_vec(vec![state0.epsilon, state0.log_delta]),
"SAS",
wiggle.clone(),
wiggle.clone(),
|theta| {
state_from_sasspec(SasLinkSpec {
initial_epsilon: theta[0],
initial_log_delta: theta[1],
})
.map(InverseLink::Sas)
},
|rho| {
state_from_sasspec(SasLinkSpec {
initial_epsilon: rho[0],
initial_log_delta: rho[1],
})
.ok()
.map(InverseLink::Sas)
},
),
// BetaLogistic: same two-parameter layout, different state constructor.
InverseLink::BetaLogistic(state0) => optimize_link_parameters(
data,
spec,
kappa_options,
Array1::from_vec(vec![state0.epsilon, state0.log_delta]),
"BetaLogistic",
wiggle.clone(),
wiggle.clone(),
|theta| {
state_from_beta_logisticspec(SasLinkSpec {
initial_epsilon: theta[0],
initial_log_delta: theta[1],
})
.map(InverseLink::BetaLogistic)
},
|rho| {
state_from_beta_logisticspec(SasLinkSpec {
initial_epsilon: rho[0],
initial_log_delta: rho[1],
})
.ok()
.map(InverseLink::BetaLogistic)
},
),
// Mixture with a non-empty rho vector: rho is optimized, the components stay fixed.
InverseLink::Mixture(state0) if !state0.rho.is_empty() => {
let components = state0.components.clone();
let components_recover = components.clone();
optimize_link_parameters(
data,
spec,
kappa_options,
state0.rho.clone(),
"mixture",
wiggle.clone(),
wiggle.clone(),
move |rho| {
state_fromspec(&MixtureLinkSpec {
components: components.clone(),
initial_rho: rho.clone(),
})
.map(InverseLink::Mixture)
},
move |rho| {
state_fromspec(&MixtureLinkSpec {
components: components_recover.clone(),
initial_rho: rho.to_owned(),
})
.ok()
.map(InverseLink::Mixture)
},
)
}
// Every other link is profiled once as given.
_ => profile_survival_location_scale(data, spec.clone(), wiggle, kappa_options),
}
}
let profile = if request.optimize_inverse_link {
optimize_survival_inverse_link_profile(
request.data,
&request.spec,
request.wiggle.clone(),
&request.kappa_options,
)?
} else {
profile_survival_location_scale(
request.data,
request.spec.clone(),
request.wiggle.clone(),
&request.kappa_options,
)?
};
Ok(profile.into_result())
}
pub(crate) fn fit_bernoulli_marginal_slope_model(
request: BernoulliMarginalSlopeFitRequest<'_>,
) -> Result<BernoulliMarginalSlopeFitResult, String> {
fit_bernoulli_marginal_slope_terms(
request.data,
request.spec,
&request.options,
&request.kappa_options,
&request.policy,
)
}
pub(crate) fn fit_survival_marginal_slope_model(
request: SurvivalMarginalSlopeFitRequest<'_>,
) -> Result<SurvivalMarginalSlopeFitResult, String> {
fit_survival_marginal_slope_terms(
request.data,
request.spec,
&request.options,
&request.kappa_options,
)
}
pub(crate) fn fit_latent_survival_model(
request: LatentSurvivalFitRequest<'_>,
) -> Result<LatentSurvivalTermFitResult, String> {
fit_latent_survival_terms(
request.data,
request.spec,
request.frailty,
&request.options,
)
}
pub(crate) fn fit_latent_binary_model(
request: LatentBinaryFitRequest<'_>,
) -> Result<LatentBinaryTermFitResult, String> {
fit_latent_binary_terms(
request.data,
request.spec,
request.frailty,
&request.options,
)
}
pub(crate) fn fit_transformation_normal_model(
request: TransformationNormalFitRequest<'_>,
) -> Result<TransformationNormalFitResult, String> {
fit_transformation_normal(
&request.response,
&request.weights,
&request.offset,
request.data,
&request.covariate_spec,
&request.config,
&request.options,
&request.kappa_options,
request.warm_start.as_ref(),
)
}
/// Chooses the number of cross-fitting folds for a dataset with `n` rows.
///
/// The mapping is a fixed step function of `n`:
/// - fewer than 3 rows: 2 folds
/// - 3 to 249 rows: 3 folds
/// - 250 to 199_999 rows: 5 folds
/// - 200_000 to 1_999_999 rows: 3 folds
/// - 2_000_000 rows or more: 2 folds
fn crossfit_fold_count(n: usize) -> usize {
    match n {
        // Below 250 rows the count is `n` clamped into [2, 3].
        0..=2 => 2,
        3..=249 => 3,
        250..=199_999 => 5,
        200_000..=1_999_999 => 3,
        _ => 2,
    }
}
/// Splits the row indices `0..n` into `k` contiguous folds, in order.
///
/// Every fold receives `n / k` indices, and the first `n % k` folds receive
/// one extra, so fold sizes differ by at most one. When `n < k` the trailing
/// folds are empty.
///
/// # Panics
/// Panics on division by zero when `k == 0`.
fn crossfit_partition(n: usize, k: usize) -> Vec<Vec<usize>> {
    let (quotient, extra) = (n / k, n % k);
    // Running cursor: first row index not yet assigned to a fold.
    let mut cursor = 0usize;
    (0..k)
        .map(|fold_idx| {
            let size = if fold_idx < extra {
                quotient + 1
            } else {
                quotient
            };
            let rows: Vec<usize> = (cursor..cursor + size).collect();
            cursor += size;
            rows
        })
        .collect()
}
/// Gathers the entries of `source` at the positions listed in `indices`,
/// in the order given, into a new owned 1-D array.
///
/// Uses plain indexing, so an out-of-range index aborts via the array's
/// index panic.
fn crossfit_select_rows_1d(source: &Array1<f64>, indices: &[usize]) -> Array1<f64> {
    let mut picked = Vec::with_capacity(indices.len());
    for &row in indices {
        picked.push(source[row]);
    }
    Array1::from_vec(picked)
}
/// Builds out-of-fold (OOF) score values and their influence Jacobian by
/// K-fold cross-fitting a transformation-normal stage-1 model.
///
/// Steps, in order:
/// 1. Without a `recipe`, returns `Ok(None)` immediately.
/// 2. Resolves the response, weight and offset columns named by the recipe.
/// 3. Parses `"<response> ~ <covariate rhs>"`, builds the covariate term spec
///    and its design on the full data, and freezes the spec from that design
///    so that every fold is fit against the same frozen spec.
/// 4. Picks the fold count with `crossfit_fold_count` and partitions rows with
///    `crossfit_partition`.
/// 5. Computes the response internal-knot count once, using the smallest
///    training-complement size, and marks it pinned in the per-fold config.
/// 6. For each non-empty fold: fits on the complement rows (with default
///    blockwise and length-scale options and no warm start), evaluates
///    `score_influence_jacobian` on the held-out rows, validates the returned
///    shapes, and scatters `z` and the Jacobian columns into the full-length
///    OOF arrays at the held-out row positions.
///
/// # Errors
/// Returns an error string when the dataset has zero rows, when any column
/// resolution / formula parsing / spec building / design building / freezing
/// step fails, when a fold's training complement is empty, when a fold fit or
/// Jacobian evaluation fails, when a fold's Jacobian row count or `z` length
/// differs from the held-out fold size, when the Jacobian column count differs
/// between folds, or when no fold contributed any held-out rows.
pub(crate) fn crossfit_score_calibration(
    data: &Dataset,
    col_map: &HashMap<String, usize>,
    recipe: Option<&CtnStage1Recipe>,
    policy: &crate::resource::ResourcePolicy,
) -> Result<Option<CrossFitScoreCalibration>, String> {
    let recipe = match recipe {
        Some(active) => active,
        None => return Ok(None),
    };

    let n_rows = data.values.nrows();
    if n_rows == 0 {
        return Err(String::from(
            "cross-fit score calibration requires a non-empty dataset",
        ));
    }

    // Full-data response / weight / offset vectors; folds are sliced from these.
    let response_col = resolve_role_col(col_map, &recipe.response_column, "response")
        .map_err(|e| e.to_string())?;
    let response_full = data.values.column(response_col).to_owned();
    let weights_full = resolve_weight_column(data, col_map, recipe.weight_column.as_deref())
        .map_err(|e| e.to_string())?;
    let offset_full = resolve_offset_column(data, col_map, recipe.offset_column.as_deref())
        .map_err(|e| e.to_string())?;

    // Covariate spec: built and frozen once against the full-data design.
    let formula_text = format!(
        "{} ~ {}",
        recipe.response_column, recipe.covariate_formula_rhs
    );
    let parsed_cov = parse_formula(&formula_text).map_err(|e| e.to_string())?;
    let mut frozen_notes = Vec::new();
    let covariate_spec_raw = build_termspec_with_geometry_and_overrides(
        &parsed_cov.terms,
        data,
        col_map,
        &mut frozen_notes,
        false,
        policy,
        None,
    )
    .map_err(|e| e.to_string())?;
    let full_cov_design = build_term_collection_design(data.values.view(), &covariate_spec_raw)
        .map_err(|e| e.to_string())?;
    let frozen_cov_spec =
        crate::smooth::freeze_term_collection_from_design(&covariate_spec_raw, &full_cov_design)
            .map_err(|e| e.to_string())?;
    let p_cov = full_cov_design.design.ncols();

    let fold_count = crossfit_fold_count(n_rows);
    let folds = crossfit_partition(n_rows, fold_count);

    // The knot count is derived from the smallest training set any fold will
    // see and then pinned, so the same value is used for every fold fit.
    let smallest_train = folds
        .iter()
        .map(|fold| n_rows - fold.len())
        .min()
        .unwrap_or(n_rows);
    let mut fold_config = recipe.config.clone();
    fold_config.response_num_internal_knots =
        crate::families::transformation_normal::effective_response_num_internal_knots(
            &recipe.config,
            smallest_train,
            p_cov,
            response_full.view(),
        );
    fold_config.response_num_internal_knots_pinned = true;

    let mut oof_scores = Array1::<f64>::zeros(n_rows);
    // Allocated lazily: the column count is only known after the first fold.
    let mut oof_jacobian: Option<Array2<f64>> = None;

    for held in folds.iter().filter(|fold| !fold.is_empty()) {
        let held_lookup: std::collections::HashSet<usize> = held.iter().copied().collect();
        let train_rows: Vec<usize> = (0..n_rows)
            .filter(|row| !held_lookup.contains(row))
            .collect();
        if train_rows.is_empty() {
            return Err(
                "cross-fit fold left an empty training complement; too few rows for K folds"
                    .to_string(),
            );
        }

        // Fit on everything outside the held-out fold.
        let train_cov = data.values.select(Axis(0), &train_rows);
        let train_resp = crossfit_select_rows_1d(&response_full, &train_rows);
        let train_weights = crossfit_select_rows_1d(&weights_full, &train_rows);
        let train_offset = crossfit_select_rows_1d(&offset_full, &train_rows);
        let fold_fit = fit_transformation_normal(
            &train_resp,
            &train_weights,
            &train_offset,
            train_cov.view(),
            &frozen_cov_spec,
            &fold_config,
            &BlockwiseFitOptions::default(),
            &SpatialLengthScaleOptimizationOptions::default(),
            None,
        )?;

        // Evaluate the fitted fold model on the held-out rows only.
        let held_cov = data.values.select(Axis(0), held);
        let held_resp = crossfit_select_rows_1d(&response_full, held);
        let held_offset = crossfit_select_rows_1d(&offset_full, held);
        let jac = crate::families::marginal_slope_orthogonal::score_influence_jacobian(
            &fold_fit,
            &held_resp,
            held_cov.view(),
            &held_offset,
        )?;

        if jac.columns.nrows() != held.len() {
            return Err(format!(
                "cross-fit fold Jacobian row count {} != held-out fold size {}",
                jac.columns.nrows(),
                held.len()
            ));
        }
        if jac.z.len() != held.len() {
            return Err(format!(
                "cross-fit fold OOF z length {} != held-out fold size {}",
                jac.z.len(),
                held.len()
            ));
        }

        let p1 = jac.columns.ncols();
        let assembled = oof_jacobian.get_or_insert_with(|| Array2::<f64>::zeros((n_rows, p1)));
        if assembled.ncols() != p1 {
            return Err(format!(
                "cross-fit fold p₁ mismatch: this fold has {p1} columns but a prior fold had {}; \
                 the frozen response/covariate basis failed to align across folds",
                assembled.ncols()
            ));
        }

        // Scatter fold-local rows back to their positions in the full dataset.
        for (fold_row, data_row) in held.iter().copied().enumerate() {
            oof_scores[data_row] = jac.z[fold_row];
            for col in 0..p1 {
                assembled[[data_row, col]] = jac.columns[[fold_row, col]];
            }
        }
    }

    let oof_jacobian = oof_jacobian.ok_or_else(|| {
        "cross-fit produced no folds with held-out rows; cannot assemble OOF Jacobian".to_string()
    })?;
    Ok(Some(CrossFitScoreCalibration {
        z_oof: oof_scores,
        jac_oof: oof_jacobian,
    }))
}