1use crate::custom_family::BlockwiseFitOptions;
2use crate::estimate::{EstimationError, FitOptions, FittedLinkState, UnifiedFitResult};
3use crate::families::bernoulli_marginal_slope::{
4 BernoulliMarginalSlopeFitResult, BernoulliMarginalSlopeTermSpec, DeviationBlockConfig,
5 fit_bernoulli_marginal_slope_terms,
6};
7use crate::families::gamlss::{
8 BinomialLocationScaleFitResult, BinomialLocationScaleTermSpec, BlockwiseTermFitResult,
9 BlockwiseTermFitResultParts, GaussianLocationScaleFitResult, GaussianLocationScaleTermSpec,
10 WiggleBlockConfig, fit_binomial_location_scale_terms,
11 fit_binomial_location_scale_terms_with_selected_wiggle,
12 fit_binomial_mean_wiggle_terms_with_selected_basis, fit_gaussian_location_scale_terms,
13 fit_gaussian_location_scale_terms_with_selected_wiggle,
14 select_binomial_location_scale_link_wiggle_basis_from_pilot,
15 select_binomial_mean_link_wiggle_basis_from_pilot,
16 select_gaussian_location_scale_link_wiggle_basis_from_pilot,
17};
18use crate::families::latent_survival::{
19 LatentBinaryTermFitResult, LatentBinaryTermSpec, LatentSurvivalTermFitResult,
20 LatentSurvivalTermSpec, fit_latent_binary_terms, fit_latent_survival_terms,
21 latent_hazard_loading,
22};
23use crate::families::lognormal_kernel::FrailtySpec;
24use crate::families::survival_location_scale::{
25 SurvivalLocationScaleTermFitResult, SurvivalLocationScaleTermSpec,
26 fit_survival_location_scale_terms, fit_survival_location_scale_terms_with_selected_wiggle,
27 select_survival_link_wiggle_basis_from_pilot,
28};
29use crate::families::survival_marginal_slope::{
30 SurvivalMarginalSlopeFitResult, SurvivalMarginalSlopeTermSpec,
31 fit_survival_marginal_slope_terms,
32};
33use crate::families::transformation_normal::{
34 TransformationNormalConfig, TransformationNormalFitResult, TransformationWarmStart,
35 fit_transformation_normal,
36};
37use crate::mixture_link::{state_from_beta_logisticspec, state_from_sasspec, state_fromspec};
38use crate::smooth::{
39 AdaptiveRegularizationDiagnostics, SpatialLengthScaleOptimizationOptions, TermCollectionDesign,
40 TermCollectionSpec, build_term_collection_design,
41 fit_term_collectionwith_spatial_length_scale_optimization,
42};
43use crate::types::{
44 InverseLink, LatentCLogLogState, LikelihoodFamily, LinkFunction, MixtureLinkSpec, SasLinkSpec,
45 WigglePenaltyConfig,
46};
47use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
48use std::collections::HashMap;
49
/// Link-wiggle basis settings supplied by a fit request.
///
/// `degree`, `num_internal_knots` and `double_penalty` are copied into a
/// `WiggleBlockConfig` for pilot-based basis selection, while `penalty_orders`
/// is handed to the basis selector separately.
#[derive(Clone, Debug)]
pub struct LinkWiggleConfig {
    /// Degree of the wiggle basis.
    pub degree: usize,
    /// Number of internal knots of the wiggle basis.
    pub num_internal_knots: usize,
    /// Penalty orders forwarded to the wiggle basis selector.
    pub penalty_orders: Vec<usize>,
    /// Forwarded to `WiggleBlockConfig::double_penalty`.
    pub double_penalty: bool,
}

/// Link-wiggle request for the standard (single-predictor) workflow.
#[derive(Clone, Debug)]
pub struct StandardBinomialWiggleConfig {
    /// Fallback inverse link, used when the pilot fit reports a standard link
    /// state without an explicit link.
    pub link_kind: InverseLink,
    /// Wiggle basis settings.
    pub wiggle: LinkWiggleConfig,
}
63
/// Request for the standard (single-predictor) term-collection fit, with an
/// optional link-wiggle refit.
pub struct StandardFitRequest<'a> {
    /// Input data used to build the term-collection design.
    pub data: ArrayView2<'a, f64>,
    /// Response vector.
    pub y: Array1<f64>,
    /// Prior weights.
    pub weights: Array1<f64>,
    /// Linear-predictor offset for the pilot fit.
    pub offset: Array1<f64>,
    /// Term-collection specification of the model.
    pub spec: TermCollectionSpec,
    /// Likelihood family of the pilot fit.
    pub family: LikelihoodFamily,
    /// Options for the pilot fit.
    pub options: FitOptions,
    /// Spatial length-scale optimization options.
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
    /// Optional link wiggle; when `Some`, `wiggle_options` must also be `Some`.
    pub wiggle: Option<StandardBinomialWiggleConfig>,
    /// Blockwise options for the wiggle refit (required iff `wiggle` is set).
    pub wiggle_options: Option<BlockwiseFitOptions>,
}

/// Request for a Gaussian location-scale fit with an optional link wiggle.
pub struct GaussianLocationScaleFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub spec: GaussianLocationScaleTermSpec,
    /// When `Some`, a pilot fit selects the wiggle basis before the final fit.
    pub wiggle: Option<LinkWiggleConfig>,
    pub options: BlockwiseFitOptions,
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
}

/// Request for a binomial location-scale fit with an optional link wiggle.
pub struct BinomialLocationScaleFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub spec: BinomialLocationScaleTermSpec,
    /// When `Some`, a pilot fit selects the wiggle basis before the final fit.
    pub wiggle: Option<LinkWiggleConfig>,
    pub options: BlockwiseFitOptions,
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
}

/// Request for a survival location-scale fit.
pub struct SurvivalLocationScaleFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub spec: SurvivalLocationScaleTermSpec,
    /// Optional link wiggle applied on top of the survival link.
    pub wiggle: Option<LinkWiggleConfig>,
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
    /// NOTE(review): presumably requests optimization of the inverse-link
    /// parameters; consumed outside this view — confirm.
    pub optimize_inverse_link: bool,
}

/// Request for a survival transformation fit.
pub struct SurvivalTransformationFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub spec: SurvivalTransformationTermSpec,
}
105
/// Inputs for the survival transformation workflow (time basis + covariate
/// terms fitted as a single working model).
#[derive(Clone)]
pub struct SurvivalTransformationTermSpec {
    /// Entry ages, one per row.
    pub age_entry: Array1<f64>,
    /// Exit ages, one per row.
    pub age_exit: Array1<f64>,
    /// Event indicator per row.
    pub event_target: Array1<u8>,
    /// Prior weights per row.
    pub weights: Array1<f64>,
    /// Covariate term specification; its design is built from the request data.
    pub covariate_spec: TermCollectionSpec,
    /// Offset added to both the entry and the exit linear predictor.
    pub covariate_offset: Array1<f64>,
    /// Initial baseline configuration (may be re-selected when its target is not linear).
    pub baseline_cfg: crate::families::survival_construction::SurvivalBaselineConfig,
    /// Likelihood mode (e.g. Weibull) controlling seeding and monotonicity handling.
    pub likelihood_mode: crate::families::survival_construction::SurvivalLikelihoodMode,
    /// Time anchor forwarded to the time-stack builder.
    pub time_anchor: f64,
    /// Pre-built time basis description.
    pub time_build: crate::families::survival_construction::SurvivalTimeBuildOutput,
    /// Optional baseline time wiggle.
    pub timewiggle: Option<LinkWiggleFormulaSpec>,
    /// `(scale, shape)` seed used for Weibull fits without a time wiggle.
    pub weibull_seed: Option<(f64, f64)>,
    /// Ridge strength on the coefficients; a ridge block is added only when > 0.
    pub ridge_lambda: f64,
}
122
123pub(crate) fn survival_inverse_link_has_free_parameters(link: &InverseLink) -> bool {
124 match link {
125 InverseLink::Sas(_) | InverseLink::BetaLogistic(_) => true,
126 InverseLink::Mixture(state) => !state.rho.is_empty(),
127 InverseLink::LatentCLogLog(_) | InverseLink::Standard(_) => false,
128 }
129}
130
131fn recover_converged_survival_inverse_link<R>(
132 result: crate::solver::outer_strategy::OuterResult,
133 context: &str,
134 recover: R,
135) -> Result<InverseLink, String>
136where
137 R: FnOnce(&Array1<f64>) -> Option<InverseLink>,
138{
139 if !result.converged {
140 return Err(format!(
141 "{context} did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={:.3e})",
142 result.iterations, result.final_value, result.final_grad_norm
143 ));
144 }
145 recover(&result.rho).ok_or_else(|| {
146 format!(
147 "{context} produced an invalid inverse-link state at rho={:?}",
148 result.rho.to_vec()
149 )
150 })
151}
152
/// Request for a Bernoulli marginal-slope fit.
pub struct BernoulliMarginalSlopeFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub spec: BernoulliMarginalSlopeTermSpec,
    pub options: BlockwiseFitOptions,
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
    /// Resource policy forwarded to the fit.
    pub policy: crate::resource::ResourcePolicy,
}

/// Request for a survival marginal-slope fit.
pub struct SurvivalMarginalSlopeFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub spec: SurvivalMarginalSlopeTermSpec,
    pub options: BlockwiseFitOptions,
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
}

/// Request for a latent survival fit with a frailty specification.
pub struct LatentSurvivalFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub spec: LatentSurvivalTermSpec,
    pub frailty: FrailtySpec,
    pub options: BlockwiseFitOptions,
}

/// Request for a latent binary fit with a frailty specification.
pub struct LatentBinaryFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub spec: LatentBinaryTermSpec,
    pub frailty: FrailtySpec,
    pub options: BlockwiseFitOptions,
}

/// Request for a transformation-normal fit.
pub struct TransformationNormalFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub response: Array1<f64>,
    pub weights: Array1<f64>,
    pub offset: Array1<f64>,
    pub covariate_spec: TermCollectionSpec,
    pub config: TransformationNormalConfig,
    pub options: BlockwiseFitOptions,
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
    /// Optional warm start for the transformation fit.
    pub warm_start: Option<TransformationWarmStart>,
}

/// Family-specific fit request; each variant wraps the matching request struct.
pub enum FitRequest<'a> {
    Standard(StandardFitRequest<'a>),
    GaussianLocationScale(GaussianLocationScaleFitRequest<'a>),
    BinomialLocationScale(BinomialLocationScaleFitRequest<'a>),
    SurvivalLocationScale(SurvivalLocationScaleFitRequest<'a>),
    SurvivalTransformation(SurvivalTransformationFitRequest<'a>),
    BernoulliMarginalSlope(BernoulliMarginalSlopeFitRequest<'a>),
    SurvivalMarginalSlope(SurvivalMarginalSlopeFitRequest<'a>),
    LatentSurvival(LatentSurvivalFitRequest<'a>),
    LatentBinary(LatentBinaryFitRequest<'a>),
    TransformationNormal(TransformationNormalFitRequest<'a>),
}
206
/// Result of the standard (single-predictor) workflow.
pub struct StandardFitResult {
    /// Returned fit: the wiggle refit when a link wiggle was requested,
    /// otherwise the pilot fit.
    pub fit: UnifiedFitResult,
    /// Design of the returned fit.
    pub design: TermCollectionDesign,
    /// Term-collection spec resolved for the returned fit.
    pub resolvedspec: TermCollectionSpec,
    /// Adaptive-regularization diagnostics of the pilot fit.
    pub adaptive_diagnostics: Option<AdaptiveRegularizationDiagnostics>,
    /// Link state of the pilot fit; kept even when a wiggle refit replaces `fit`.
    pub saved_link_state: FittedLinkState,
    /// Wiggle knots; `Some` only when the wiggle refit ran.
    pub wiggle_knots: Option<Array1<f64>>,
    /// Wiggle degree; `Some` only when the wiggle refit ran.
    pub wiggle_degree: Option<usize>,
}

/// Result of the survival location-scale workflow.
pub struct SurvivalLocationScaleFitResult {
    pub fit: SurvivalLocationScaleTermFitResult,
    /// Inverse link the returned fit was computed with.
    pub inverse_link: InverseLink,
    /// Wiggle knots; `Some` only when a link wiggle was fitted.
    pub wiggle_knots: Option<Array1<f64>>,
    /// Wiggle degree; `Some` only when a link wiggle was fitted.
    pub wiggle_degree: Option<usize>,
}

/// Result of the survival transformation workflow, including everything needed
/// to rebuild the time basis.
pub struct SurvivalTransformationFitResult {
    /// Single-block unified fit over the time and covariate coefficients.
    pub fit: UnifiedFitResult,
    /// Covariate term spec frozen from the fitted design.
    pub resolvedspec: TermCollectionSpec,
    /// Baseline used for the final fit. For Weibull fits without a time wiggle,
    /// scale/shape are recovered from the fitted time coefficients.
    pub baseline_cfg: crate::families::survival_construction::SurvivalBaselineConfig,
    pub likelihood_mode: crate::families::survival_construction::SurvivalLikelihoodMode,
    pub time_anchor: f64,
    /// Time-basis metadata copied from the request's `time_build`.
    pub time_basisname: String,
    pub time_base_ncols: usize,
    pub time_degree: Option<usize>,
    pub time_knots: Option<Vec<f64>>,
    pub time_keep_cols: Option<Vec<usize>>,
    pub time_smooth_lambda: Option<f64>,
    /// Baseline time-wiggle block from the prepared time stack, if any.
    pub baseline_timewiggle: Option<TimeWiggleBlockInput>,
}
238
239struct SurvivalLocationScaleProfile {
240 fit: SurvivalLocationScaleTermFitResult,
241 inverse_link: InverseLink,
242 wiggle_knots: Option<Array1<f64>>,
243 wiggle_degree: Option<usize>,
244}
245
246impl SurvivalLocationScaleProfile {
247 fn objective(&self) -> f64 {
248 self.fit.fit.reml_score
249 }
250
251 fn into_result(self) -> SurvivalLocationScaleFitResult {
252 SurvivalLocationScaleFitResult {
253 fit: self.fit,
254 inverse_link: self.inverse_link,
255 wiggle_knots: self.wiggle_knots,
256 wiggle_degree: self.wiggle_degree,
257 }
258 }
259}
260
/// Family-specific fit result; each variant mirrors the matching `FitRequest` variant.
pub enum FitResult {
    Standard(StandardFitResult),
    GaussianLocationScale(GaussianLocationScaleFitResult),
    BinomialLocationScale(BinomialLocationScaleFitResult),
    SurvivalLocationScale(SurvivalLocationScaleFitResult),
    SurvivalTransformation(SurvivalTransformationFitResult),
    BernoulliMarginalSlope(BernoulliMarginalSlopeFitResult),
    SurvivalMarginalSlope(SurvivalMarginalSlopeFitResult),
    LatentSurvival(LatentSurvivalTermFitResult),
    LatentBinary(LatentBinaryTermFitResult),
    TransformationNormal(TransformationNormalFitResult),
}
273
274fn resolved_wiggle_inverse_link(
275 family: LikelihoodFamily,
276 fit: &UnifiedFitResult,
277 fallback: &InverseLink,
278) -> Result<InverseLink, String> {
279 let resolved = match fit.fitted_link_state(family).map_err(|e| e.to_string())? {
280 FittedLinkState::Standard(Some(link)) => InverseLink::Standard(link),
281 FittedLinkState::Standard(None) => fallback.clone(),
282 FittedLinkState::LatentCLogLog { state } => InverseLink::LatentCLogLog(state),
283 FittedLinkState::Sas { state, .. } => InverseLink::Sas(state),
284 FittedLinkState::BetaLogistic { state, .. } => InverseLink::BetaLogistic(state),
285 FittedLinkState::Mixture { state, .. } => InverseLink::Mixture(state),
286 };
287 require_inverse_link_supports_joint_wiggle(&resolved, "standard link wiggle")?;
288 Ok(resolved)
289}
290
291fn deviation_block_config_from_formula_linkwiggle(
292 wiggle: &LinkWiggleFormulaSpec,
293) -> DeviationBlockConfig {
294 let defaults = WigglePenaltyConfig::cubic_triple_operator_default();
295 DeviationBlockConfig {
296 degree: wiggle.degree,
297 num_internal_knots: wiggle.num_internal_knots,
298 penalty_order: *wiggle.penalty_orders.iter().max().unwrap_or(&2),
299 penalty_orders: wiggle.penalty_orders.clone(),
300 double_penalty: wiggle.double_penalty,
301 monotonicity_eps: defaults.monotonicity_eps,
302 }
303}
304
305struct MarginalSlopeDeviationRouting {
306 score_warp: Option<DeviationBlockConfig>,
307 link_dev: Option<DeviationBlockConfig>,
308}
309
310fn route_marginal_slope_deviation_blocks(
311 main_linkwiggle: Option<&LinkWiggleFormulaSpec>,
312 logslope_linkwiggle: Option<&LinkWiggleFormulaSpec>,
313) -> Result<MarginalSlopeDeviationRouting, String> {
314 Ok(MarginalSlopeDeviationRouting {
315 score_warp: logslope_linkwiggle.map(deviation_block_config_from_formula_linkwiggle),
316 link_dev: main_linkwiggle.map(deviation_block_config_from_formula_linkwiggle),
317 })
318}
319
320fn fixed_gaussian_shift_frailty_from_spec(
321 frailty: &FrailtySpec,
322 context: &str,
323) -> Result<FrailtySpec, String> {
324 match frailty {
325 FrailtySpec::None => Ok(FrailtySpec::None),
326 FrailtySpec::GaussianShift {
327 sigma_fixed: Some(sigma),
328 } => Ok(FrailtySpec::GaussianShift {
329 sigma_fixed: Some(*sigma),
330 }),
331 FrailtySpec::GaussianShift { sigma_fixed: None } => Err(format!(
332 "{context} currently requires a fixed GaussianShift sigma"
333 )),
334 FrailtySpec::HazardMultiplier { .. } => Err(format!(
335 "{context} requires FrailtySpec::GaussianShift or no frailty"
336 )),
337 }
338}
339
340fn fit_standard_model(request: StandardFitRequest<'_>) -> Result<StandardFitResult, String> {
341 let fitted = fit_term_collectionwith_spatial_length_scale_optimization(
342 request.data,
343 request.y.clone(),
344 request.weights.clone(),
345 request.offset.clone(),
346 &request.spec,
347 request.family,
348 &request.options,
349 &request.kappa_options,
350 )
351 .map_err(|e| e.to_string())?;
352
353 let result = StandardFitResult {
354 saved_link_state: fitted.fit.fitted_link.clone(),
355 fit: fitted.fit,
356 design: fitted.design,
357 resolvedspec: fitted.resolvedspec,
358 adaptive_diagnostics: fitted.adaptive_diagnostics,
359 wiggle_knots: None,
360 wiggle_degree: None,
361 };
362
363 let Some(wiggle) = request.wiggle else {
364 return Ok(result);
365 };
366 let wiggle_options = request
367 .wiggle_options
368 .ok_or_else(|| "standard wiggle workflow requires blockwise wiggle options".to_string())?;
369 let wiggle_link_kind =
370 resolved_wiggle_inverse_link(request.family, &result.fit, &wiggle.link_kind)?;
371 let selected_wiggle_basis = select_binomial_mean_link_wiggle_basis_from_pilot(
372 &result.design,
373 &result.fit,
374 &WiggleBlockConfig {
375 degree: wiggle.wiggle.degree,
376 num_internal_knots: wiggle.wiggle.num_internal_knots,
377 penalty_order: 2,
378 double_penalty: wiggle.wiggle.double_penalty,
379 },
380 &wiggle.wiggle.penalty_orders,
381 )?;
382
383 let solved = fit_binomial_mean_wiggle_terms_with_selected_basis(
384 request.data,
385 &result.resolvedspec,
386 &result.design,
387 &result.fit,
388 &request.y,
389 &request.weights,
390 wiggle_link_kind,
391 selected_wiggle_basis,
392 &wiggle_options,
393 &request.kappa_options,
394 )?;
395
396 Ok(StandardFitResult {
397 saved_link_state: result.saved_link_state,
398 fit: solved.fit,
399 design: solved.design,
400 resolvedspec: solved.resolvedspec,
401 adaptive_diagnostics: result.adaptive_diagnostics,
402 wiggle_knots: Some(solved.wiggle_knots),
403 wiggle_degree: Some(solved.wiggle_degree),
404 })
405}
406
407fn fit_gaussian_location_scale_model(
408 request: GaussianLocationScaleFitRequest<'_>,
409) -> Result<GaussianLocationScaleFitResult, String> {
410 if let Some(wiggle_cfg) = request.wiggle {
411 let pilot = fit_gaussian_location_scale_terms(
412 request.data,
413 GaussianLocationScaleTermSpec {
414 y: request.spec.y.clone(),
415 weights: request.spec.weights.clone(),
416 meanspec: request.spec.meanspec.clone(),
417 log_sigmaspec: request.spec.log_sigmaspec.clone(),
418 mean_offset: request.spec.mean_offset.clone(),
419 log_sigma_offset: request.spec.log_sigma_offset.clone(),
420 },
421 &request.options,
422 &request.kappa_options,
423 )?;
424 let selected_wiggle_basis = select_gaussian_location_scale_link_wiggle_basis_from_pilot(
425 &pilot,
426 &WiggleBlockConfig {
427 degree: wiggle_cfg.degree,
428 num_internal_knots: wiggle_cfg.num_internal_knots,
429 penalty_order: 2,
430 double_penalty: wiggle_cfg.double_penalty,
431 },
432 &wiggle_cfg.penalty_orders,
433 )?;
434 let solved = fit_gaussian_location_scale_terms_with_selected_wiggle(
435 request.data,
436 request.spec,
437 selected_wiggle_basis,
438 &request.options,
439 &request.kappa_options,
440 )?;
441 let fit = solved.fit.fit;
442 let beta_link_wiggle = fit.block_states.get(2).map(|b| b.beta.to_vec());
443 Ok(GaussianLocationScaleFitResult {
444 fit: BlockwiseTermFitResult::try_from_parts(BlockwiseTermFitResultParts {
445 fit,
446 meanspec_resolved: solved.fit.meanspec_resolved,
447 noisespec_resolved: solved.fit.noisespec_resolved,
448 mean_design: solved.fit.mean_design,
449 noise_design: solved.fit.noise_design,
450 })?,
451 wiggle_knots: Some(solved.wiggle_knots),
452 wiggle_degree: Some(solved.wiggle_degree),
453 beta_link_wiggle,
454 })
455 } else {
456 let fit = fit_gaussian_location_scale_terms(
457 request.data,
458 request.spec,
459 &request.options,
460 &request.kappa_options,
461 )?;
462 Ok(GaussianLocationScaleFitResult {
463 fit,
464 wiggle_knots: None,
465 wiggle_degree: None,
466 beta_link_wiggle: None,
467 })
468 }
469}
470
471fn fit_binomial_location_scale_model(
472 request: BinomialLocationScaleFitRequest<'_>,
473) -> Result<BinomialLocationScaleFitResult, String> {
474 if let Some(wiggle_cfg) = request.wiggle {
475 require_inverse_link_supports_joint_wiggle(
476 &request.spec.link_kind,
477 "binomial location-scale link wiggle",
478 )?;
479 let pilot = fit_binomial_location_scale_terms(
480 request.data,
481 BinomialLocationScaleTermSpec {
482 y: request.spec.y.clone(),
483 weights: request.spec.weights.clone(),
484 link_kind: request.spec.link_kind.clone(),
485 thresholdspec: request.spec.thresholdspec.clone(),
486 log_sigmaspec: request.spec.log_sigmaspec.clone(),
487 threshold_offset: request.spec.threshold_offset.clone(),
488 log_sigma_offset: request.spec.log_sigma_offset.clone(),
489 },
490 &request.options,
491 &request.kappa_options,
492 )?;
493 let selected_wiggle_basis = select_binomial_location_scale_link_wiggle_basis_from_pilot(
494 &pilot,
495 &WiggleBlockConfig {
496 degree: wiggle_cfg.degree,
497 num_internal_knots: wiggle_cfg.num_internal_knots,
498 penalty_order: 2,
499 double_penalty: wiggle_cfg.double_penalty,
500 },
501 &wiggle_cfg.penalty_orders,
502 )?;
503 let solved = fit_binomial_location_scale_terms_with_selected_wiggle(
504 request.data,
505 request.spec,
506 selected_wiggle_basis,
507 &request.options,
508 &request.kappa_options,
509 )?;
510 let fit = solved.fit.fit;
511 let beta_link_wiggle = fit.block_states.get(2).map(|b| b.beta.to_vec());
512 Ok(BinomialLocationScaleFitResult {
513 fit: BlockwiseTermFitResult::try_from_parts(BlockwiseTermFitResultParts {
514 fit,
515 meanspec_resolved: solved.fit.meanspec_resolved,
516 noisespec_resolved: solved.fit.noisespec_resolved,
517 mean_design: solved.fit.mean_design,
518 noise_design: solved.fit.noise_design,
519 })?,
520 wiggle_knots: Some(solved.wiggle_knots),
521 wiggle_degree: Some(solved.wiggle_degree),
522 beta_link_wiggle,
523 })
524 } else {
525 let solved = fit_binomial_location_scale_terms(
526 request.data,
527 request.spec,
528 &request.options,
529 &request.kappa_options,
530 )?;
531 Ok(BinomialLocationScaleFitResult {
532 fit: solved,
533 wiggle_knots: None,
534 wiggle_degree: None,
535 beta_link_wiggle: None,
536 })
537 }
538}
539
540fn survival_working_reml_score(state: &crate::pirls::WorkingState) -> f64 {
541 0.5 * (state.deviance + state.penalty_term)
542}
543
544fn fitted_weibull_baseline_from_linear_time_beta(
545 beta: &Array1<f64>,
546) -> Option<crate::families::survival_construction::SurvivalBaselineConfig> {
547 if beta.len() < 2 {
548 return None;
549 }
550 let shape = beta[1];
551 if !shape.is_finite() || shape <= 0.0 {
552 return None;
553 }
554 let scale = (-beta[0] / shape).exp();
555 if !scale.is_finite() || scale <= 0.0 {
556 return None;
557 }
558 Some(
559 crate::families::survival_construction::SurvivalBaselineConfig {
560 target: SurvivalBaselineTarget::Weibull,
561 scale: Some(scale),
562 shape: Some(shape),
563 rate: None,
564 makeham: None,
565 },
566 )
567}
568
569fn survival_unified_fit_result(
570 beta: Array1<f64>,
571 lambdas: Array1<f64>,
572 summary: &crate::pirls::WorkingModelPirlsResult,
573 state: &crate::pirls::WorkingState,
574) -> Result<UnifiedFitResult, String> {
575 let log_lambdas = lambdas.mapv(|v| v.max(1e-300).ln());
576 let reml_score = survival_working_reml_score(state);
577 crate::estimate::validate_all_finite("survival fit beta", beta.iter().copied())?;
578 crate::estimate::validate_all_finite("survival fit lambdas", lambdas.iter().copied())?;
579 crate::estimate::ensure_finite_scalar("survival fit log_likelihood", state.log_likelihood)?;
580 crate::estimate::ensure_finite_scalar("survival fit deviance", state.deviance)?;
581 crate::estimate::ensure_finite_scalar("survival fit penalty", state.penalty_term)?;
582 crate::estimate::ensure_finite_scalar("survival fit reml_score", reml_score)?;
583 crate::estimate::ensure_finite_scalar("survival fit gradient_norm", summary.lastgradient_norm)?;
584 crate::estimate::ensure_finite_scalar("survival fit max_abs_eta", summary.max_abs_eta)?;
585
586 UnifiedFitResult::try_from_parts(crate::estimate::UnifiedFitResultParts {
587 blocks: vec![crate::estimate::FittedBlock {
588 beta: beta.clone(),
589 role: crate::estimate::BlockRole::Mean,
590 edf: 0.0,
591 lambdas: lambdas.clone(),
592 }],
593 log_lambdas,
594 lambdas,
595 likelihood_family: Some(LikelihoodFamily::RoystonParmar),
596 likelihood_scale: crate::types::LikelihoodScaleMetadata::Unspecified,
597 log_likelihood_normalization: crate::types::LogLikelihoodNormalization::UserProvided,
598 log_likelihood: state.log_likelihood,
599 deviance: state.deviance,
600 reml_score,
601 stable_penalty_term: state.penalty_term,
602 penalized_objective: reml_score,
603 outer_iterations: summary.iterations,
604 outer_converged: true,
605 outer_gradient_norm: summary.lastgradient_norm,
606 standard_deviation: 1.0,
607 covariance_conditional: None,
608 covariance_corrected: None,
609 inference: None,
610 fitted_link: FittedLinkState::Standard(None),
611 geometry: None,
612 block_states: Vec::new(),
613 pirls_status: summary.status,
614 max_abs_eta: summary.max_abs_eta,
615 constraint_kkt: None,
616 artifacts: crate::estimate::FitArtifacts {
617 pirls: None,
618 ..Default::default()
619 },
620 inner_cycles: 0,
621 })
622 .map_err(|err| err.to_string())
623}
624
625fn fit_survival_transformation_model(
626 request: SurvivalTransformationFitRequest<'_>,
627) -> Result<SurvivalTransformationFitResult, String> {
628 use crate::survival::{MonotonicityPenalty, PenaltyBlock, PenaltyBlocks, SurvivalSpec};
629
630 let SurvivalTransformationFitRequest { data, spec } = request;
631 let mut baseline_cfg = spec.baseline_cfg.clone();
632 let covariate_design =
633 build_term_collection_design(data, &spec.covariate_spec).map_err(|err| err.to_string())?;
634 let resolvedspec =
635 crate::smooth::freeze_term_collection_from_design(&spec.covariate_spec, &covariate_design)
636 .map_err(|err| err.to_string())?;
637 let dense_cov_design = covariate_design.design.to_dense();
638 let p_cov = dense_cov_design.ncols();
639 let event_competing = Array1::<u8>::zeros(spec.event_target.len());
640 let exact_derivative_guard = survival_derivative_guard_for_likelihood(spec.likelihood_mode);
641
642 let build_working_model =
643 |candidate: &crate::families::survival_construction::SurvivalBaselineConfig| {
644 let prepared = prepare_workflow_survival_time_stack(
645 &spec.age_entry,
646 &spec.age_exit,
647 candidate,
648 spec.likelihood_mode,
649 None,
650 spec.time_anchor,
651 exact_derivative_guard,
652 &spec.time_build,
653 spec.timewiggle.as_ref(),
654 None,
655 )?;
656 let mut eta_offset_entry = prepared.eta_offset_entry.clone();
657 let mut eta_offset_exit = prepared.eta_offset_exit.clone();
658 eta_offset_entry += &spec.covariate_offset;
659 eta_offset_exit += &spec.covariate_offset;
660 let p_time_total = prepared.time_design_exit.ncols();
661 let p = p_time_total + p_cov;
662 let mut penalty_blocks = Vec::<PenaltyBlock>::new();
663 for (idx, penalty) in prepared.time_penalties.iter().enumerate() {
664 if penalty.nrows() == p_time_total && penalty.ncols() == p_time_total {
665 penalty_blocks.push(PenaltyBlock {
666 matrix: penalty.clone(),
667 lambda: spec.time_build.smooth_lambda.unwrap_or(1e-2),
668 range: 0..p_time_total,
669 nullspace_dim: prepared.time_nullspace_dims.get(idx).copied().unwrap_or(0),
670 });
671 }
672 }
673 let ridge_range_start = if spec.likelihood_mode == SurvivalLikelihoodMode::Weibull
674 && spec.time_build.basisname == "linear"
675 && spec.timewiggle.is_none()
676 {
677 1
678 } else {
679 0
680 };
681 if spec.ridge_lambda > 0.0 && p > ridge_range_start {
682 let dim = p - ridge_range_start;
683 let mut ridge = Array2::<f64>::zeros((dim, dim));
684 for d in 0..dim {
685 ridge[[d, d]] = 1.0;
686 }
687 penalty_blocks.push(PenaltyBlock {
688 matrix: ridge,
689 lambda: spec.ridge_lambda,
690 range: ridge_range_start..p,
691 nullspace_dim: 0,
692 });
693 }
694 let dense_time_entry = prepared.time_design_entry.to_dense();
695 let dense_time_exit = prepared.time_design_exit.to_dense();
696 let dense_time_derivative = prepared.time_design_derivative.to_dense();
697 let mut model =
698 crate::families::royston_parmar::working_model_from_time_covariateshared(
699 PenaltyBlocks::new(penalty_blocks.clone()),
700 MonotonicityPenalty { tolerance: 0.0 },
701 SurvivalSpec::Net,
702 crate::families::royston_parmar::RoystonParmarSharedTimeCovariateInputs {
703 age_entry: spec.age_entry.view(),
704 age_exit: spec.age_exit.view(),
705 event_target: spec.event_target.view(),
706 event_competing: event_competing.view(),
707 weights: spec.weights.view(),
708 time_entry: dense_time_entry.view(),
709 time_exit: dense_time_exit.view(),
710 time_derivative: dense_time_derivative.view(),
711 covariates: dense_cov_design.view(),
712 monotonicity_constraint_rows: None,
713 monotonicity_constraint_offsets: None,
714 eta_offset_entry: Some(eta_offset_entry.view()),
715 eta_offset_exit: Some(eta_offset_exit.view()),
716 derivative_offset_exit: Some(prepared.derivative_offset_exit.view()),
717 },
718 )
719 .map_err(|err| format!("failed to construct survival model: {err}"))?;
720 if spec.likelihood_mode != SurvivalLikelihoodMode::Weibull {
721 model
722 .set_structural_monotonicity(true, p_time_total)
723 .map_err(|err| format!("failed to enable structural monotonicity: {err}"))?;
724 }
725 let mut beta0 = Array1::<f64>::zeros(p);
726 if spec.likelihood_mode == SurvivalLikelihoodMode::Weibull && spec.timewiggle.is_none()
727 {
728 let (scale, shape) = spec
729 .weibull_seed
730 .ok_or_else(|| "weibull survival fit missing scale/shape seed".to_string())?;
731 if p_time_total < 2 {
732 return Err(format!(
733 "weibull built-in time basis has {p_time_total} columns but needs 2 to seed scale/shape"
734 ));
735 }
736 beta0[0] = -shape * scale.ln();
737 beta0[1] = shape;
738 }
739 let structural_lower_bounds =
740 if spec.likelihood_mode != SurvivalLikelihoodMode::Weibull && p_time_total > 0 {
741 let mut lb = Array1::from_elem(p, f64::NEG_INFINITY);
742 for j in 0..p_time_total {
743 lb[j] = 0.0;
744 beta0[j] = 1e-4;
745 }
746 Some(lb)
747 } else {
748 None
749 };
750 Ok::<_, String>((
751 prepared,
752 penalty_blocks,
753 beta0,
754 structural_lower_bounds,
755 model,
756 ))
757 };
758
759 if baseline_cfg.target != SurvivalBaselineTarget::Linear {
760 baseline_cfg = optimize_survival_baseline_config(
761 &baseline_cfg,
762 "workflow survival transformation baseline",
763 |candidate| {
764 let (_, _, beta0, structural_lower_bounds, mut model) =
765 build_working_model(candidate)?;
766 let opts = crate::pirls::WorkingModelPirlsOptions {
767 max_iterations: 400,
768 convergence_tolerance: 1e-6,
769 max_step_halving: 40,
770 min_step_size: 1e-12,
771 firth_bias_reduction: false,
772 coefficient_lower_bounds: structural_lower_bounds,
773 linear_constraints: None,
774 initial_lm_lambda: None,
775 };
776 let summary = crate::pirls::runworking_model_pirls(
777 &mut model,
778 crate::types::Coefficients::new(beta0),
779 &opts,
780 |_| {},
781 )
782 .map_err(|err| format!("survival PIRLS failed: {err}"))?;
783 let beta = summary.beta.as_ref().to_owned();
784 let state = model.update_state(&beta).map_err(|err| {
785 format!("failed to evaluate survival baseline candidate: {err}")
786 })?;
787 Ok(survival_working_reml_score(&state))
788 },
789 )?;
790 }
791
792 let (prepared, penalty_blocks, beta0, structural_lower_bounds, mut model) =
793 build_working_model(&baseline_cfg)?;
794 let opts = crate::pirls::WorkingModelPirlsOptions {
795 max_iterations: 400,
796 convergence_tolerance: 1e-6,
797 max_step_halving: 40,
798 min_step_size: 1e-12,
799 firth_bias_reduction: false,
800 coefficient_lower_bounds: structural_lower_bounds,
801 linear_constraints: None,
802 initial_lm_lambda: None,
803 };
804 let summary = crate::pirls::runworking_model_pirls(
805 &mut model,
806 crate::types::Coefficients::new(beta0),
807 &opts,
808 |_| {},
809 )
810 .map_err(|err| format!("survival PIRLS failed: {err}"))?;
811 match summary.status {
812 crate::pirls::PirlsStatus::Converged | crate::pirls::PirlsStatus::StalledAtValidMinimum => {
813 }
814 ref other => {
815 return Err(format!(
816 "survival PIRLS did not converge: status={other:?}, grad_norm={:.3e}, iterations={}, deviance={:.6e}",
817 summary.lastgradient_norm, summary.iterations, summary.state.deviance
818 ));
819 }
820 }
821 let beta = summary.beta.as_ref().to_owned();
822 let state = model
823 .update_state(&beta)
824 .map_err(|err| format!("failed to evaluate survival optimum: {err}"))?;
825 let lambdas = Array1::from_iter(penalty_blocks.iter().map(|block| block.lambda));
826 let fitted_baseline_cfg =
827 if spec.likelihood_mode == SurvivalLikelihoodMode::Weibull && spec.timewiggle.is_none() {
828 let time_beta = beta
829 .slice(s![..spec.time_build.x_exit_time.ncols()])
830 .to_owned();
831 fitted_weibull_baseline_from_linear_time_beta(&time_beta).ok_or_else(|| {
832 "failed to recover fitted Weibull scale/shape from the linear time coefficients"
833 .to_string()
834 })?
835 } else {
836 baseline_cfg
837 };
838 let fit = survival_unified_fit_result(beta, lambdas, &summary, &state)?;
839
840 Ok(SurvivalTransformationFitResult {
841 fit,
842 resolvedspec,
843 baseline_cfg: fitted_baseline_cfg,
844 likelihood_mode: spec.likelihood_mode,
845 time_anchor: spec.time_anchor,
846 time_basisname: spec.time_build.basisname.clone(),
847 time_base_ncols: spec.time_build.x_exit_time.ncols(),
848 time_degree: spec.time_build.degree,
849 time_knots: spec.time_build.knots.clone(),
850 time_keep_cols: spec.time_build.keep_cols.clone(),
851 time_smooth_lambda: spec.time_build.smooth_lambda,
852 baseline_timewiggle: prepared.timewiggle_block,
853 })
854}
855
/// Fits a survival location-scale model, optionally tuning the inverse link.
///
/// When `request.optimize_inverse_link` is set and the requested link has free
/// parameters (SAS, Beta-logistic, or a mixture with a non-empty `rho`), those
/// parameters are tuned by a bounded, gradient-free outer search in which every
/// objective evaluation refits the whole model at the candidate link. Otherwise
/// the model is fitted once at `request.spec.inverse_link`. In both paths an
/// optional link wiggle is selected from a pilot fit.
///
/// # Errors
/// Propagates fitting, wiggle-selection and outer-optimization failures as
/// `String` messages.
fn fit_survival_location_scale_model(
    request: SurvivalLocationScaleFitRequest<'_>,
) -> Result<SurvivalLocationScaleFitResult, String> {
    /// Fits the model once at `spec.inverse_link`.
    ///
    /// With a wiggle config it checks that the link supports a joint wiggle,
    /// fits a pilot without the link-wiggle block, selects the wiggle basis
    /// from that pilot, and refits with the selected basis. The selected
    /// knots/degree are recorded in the returned profile.
    fn profile_survival_location_scale(
        data: ArrayView2<'_, f64>,
        spec: SurvivalLocationScaleTermSpec,
        wiggle: Option<LinkWiggleConfig>,
        kappa_options: &SpatialLengthScaleOptimizationOptions,
    ) -> Result<SurvivalLocationScaleProfile, String> {
        let mut wiggle_knots = None;
        let mut wiggle_degree = None;
        let inverse_link = spec.inverse_link.clone();

        let fit = if let Some(wiggle) = wiggle {
            require_inverse_link_supports_joint_wiggle(&inverse_link, "survival link wiggle")?;
            // Pilot fit without any link-wiggle block; the wiggle basis is
            // chosen from this pilot.
            let mut pilot_spec = spec.clone();
            pilot_spec.linkwiggle_block = None;
            let pilot = fit_survival_location_scale_terms(data, pilot_spec, kappa_options)?;
            let selected_wiggle_basis = select_survival_link_wiggle_basis_from_pilot(
                &pilot,
                &WiggleBlockConfig {
                    degree: wiggle.degree,
                    num_internal_knots: wiggle.num_internal_knots,
                    // NOTE(review): fixed at 2 while `wiggle.penalty_orders` is
                    // passed separately below — confirm which one governs.
                    penalty_order: 2,
                    double_penalty: wiggle.double_penalty,
                },
                &wiggle.penalty_orders,
            )?;
            // Remember the selected basis so it is reported in the result.
            wiggle_knots = Some(selected_wiggle_basis.knots.clone());
            wiggle_degree = Some(selected_wiggle_basis.degree);
            fit_survival_location_scale_terms_with_selected_wiggle(
                data,
                spec,
                selected_wiggle_basis,
                kappa_options,
            )?
        } else {
            fit_survival_location_scale_terms(data, spec, kappa_options)?
        };

        Ok(SurvivalLocationScaleProfile {
            fit,
            inverse_link,
            wiggle_knots,
            wiggle_degree,
        })
    }

    /// Profiles the model with `spec` but with its inverse link replaced by
    /// `inverse_link`.
    fn profile_survival_location_scale_with_inverse_link(
        data: ArrayView2<'_, f64>,
        spec: &SurvivalLocationScaleTermSpec,
        inverse_link: InverseLink,
        wiggle: Option<LinkWiggleConfig>,
        kappa_options: &SpatialLengthScaleOptimizationOptions,
    ) -> Result<SurvivalLocationScaleProfile, String> {
        let mut spec_at_link = spec.clone();
        spec_at_link.inverse_link = inverse_link;
        profile_survival_location_scale(data, spec_at_link, wiggle, kappa_options)
    }

    /// Tunes the parameters of `spec.inverse_link` (when it has any) and
    /// returns the profile refitted at the optimum. Links without tunable
    /// parameters are profiled once.
    fn optimize_survival_inverse_link_profile(
        data: ArrayView2<'_, f64>,
        spec: &SurvivalLocationScaleTermSpec,
        wiggle: Option<LinkWiggleConfig>,
        kappa_options: &SpatialLengthScaleOptimizationOptions,
    ) -> Result<SurvivalLocationScaleProfile, String> {
        /// Runs the outer search over link parameters starting from `init`.
        ///
        /// The search is box-constrained to `init ± 6` per coordinate and uses
        /// the auxiliary gradient-free solver class (tolerance 1e-4, at most
        /// 240 iterations). `objective` maps parameters to the profile
        /// objective. `recover` maps the converged parameters back to an
        /// `InverseLink`, which is then profiled once more with `final_wiggle`.
        fn optimize_link_parameters<F, R>(
            data: ArrayView2<'_, f64>,
            spec: &SurvivalLocationScaleTermSpec,
            kappa_options: &SpatialLengthScaleOptimizationOptions,
            init: Array1<f64>,
            name: &str,
            final_wiggle: Option<LinkWiggleConfig>,
            objective: F,
            recover: R,
        ) -> Result<SurvivalLocationScaleProfile, String>
        where
            F: FnMut(&Array1<f64>) -> Result<f64, EstimationError>,
            R: Fn(&Array1<f64>) -> Option<InverseLink>,
        {
            use crate::solver::outer_strategy::{OuterProblem, SolverClass};
            let dim = init.len();
            // Symmetric search box of half-width 6 around the starting point.
            let lower = init.mapv(|v| v - 6.0);
            let upper = init.mapv(|v| v + 6.0);
            let problem = OuterProblem::new(dim)
                .with_solver_class(SolverClass::AuxiliaryGradientFree)
                .with_tolerance(1e-4)
                .with_max_iter(240)
                .with_bounds(lower, upper)
                .with_heuristic_lambdas(init.to_vec());
            let context = format!("survival inverse-link optimization ({name}, dim={dim})");
            let mut obj = problem.build_objective(
                objective,
                // Cost callback: forwards straight to `objective`.
                |f: &mut F, rho: &ndarray::Array1<f64>| f(rho),
                // Gradient callback: per the message below, this solver class
                // only evaluates the cost, so this returns an error if ever hit.
                |_: &mut F, _: &ndarray::Array1<f64>| {
                    Err(EstimationError::InvalidInput(
                        "inverse-link aux optimizer: CompassSearch dispatch only \
                        calls eval_cost; eval(gradient) is unreachable by \
                        construction"
                            .to_string(),
                    ))
                },
                // The two optional hooks (roles defined by `build_objective`)
                // are not used here.
                None::<fn(&mut F)>,
                None::<
                    fn(
                        &mut F,
                        &ndarray::Array1<f64>,
                    )
                        -> Result<crate::solver::outer_strategy::EfsEval, EstimationError>,
                >,
            );
            let result = problem
                .run(&mut obj, &context)
                .map_err(|err| format!("{context} failed: {err}"))?;
            let link = recover_converged_survival_inverse_link(result, &context, recover)?;
            // Refit once at the recovered link to produce the returned profile.
            profile_survival_location_scale_with_inverse_link(
                data,
                spec,
                link,
                final_wiggle,
                kappa_options,
            )
            .map_err(|err| format!("{context} final profiling failed: {err}"))
        }

        // NOTE(review): an error from any single objective evaluation is
        // returned to the outer solver as `Err`; whether that aborts the whole
        // search is decided by `OuterProblem::run` (not visible here).
        match spec.inverse_link.clone() {
            // SAS: theta = (epsilon, log_delta), started at the current state.
            InverseLink::Sas(state0) => {
                let init = Array1::from_vec(vec![state0.epsilon, state0.log_delta]);
                let wiggle_cfg = wiggle.clone();
                optimize_link_parameters(
                    data,
                    spec,
                    kappa_options,
                    init,
                    "SAS",
                    wiggle.clone(),
                    |theta: &Array1<f64>| {
                        let state = state_from_sasspec(SasLinkSpec {
                            initial_epsilon: theta[0],
                            initial_log_delta: theta[1],
                        })
                        .map_err(EstimationError::InvalidInput)?;
                        Ok(profile_survival_location_scale_with_inverse_link(
                            data,
                            spec,
                            InverseLink::Sas(state),
                            wiggle_cfg.clone(),
                            kappa_options,
                        )
                        .map_err(EstimationError::InvalidInput)?
                        .objective())
                    },
                    |rho| {
                        state_from_sasspec(SasLinkSpec {
                            initial_epsilon: rho[0],
                            initial_log_delta: rho[1],
                        })
                        .ok()
                        .map(InverseLink::Sas)
                    },
                )
            }
            // Beta-logistic reuses `SasLinkSpec` for its (epsilon, log_delta)
            // parameterization.
            InverseLink::BetaLogistic(state0) => {
                let init = Array1::from_vec(vec![state0.epsilon, state0.log_delta]);
                let wiggle_cfg = wiggle.clone();
                optimize_link_parameters(
                    data,
                    spec,
                    kappa_options,
                    init,
                    "BetaLogistic",
                    wiggle.clone(),
                    |theta: &Array1<f64>| {
                        let state = state_from_beta_logisticspec(SasLinkSpec {
                            initial_epsilon: theta[0],
                            initial_log_delta: theta[1],
                        })
                        .map_err(EstimationError::InvalidInput)?;
                        Ok(profile_survival_location_scale_with_inverse_link(
                            data,
                            spec,
                            InverseLink::BetaLogistic(state),
                            wiggle_cfg.clone(),
                            kappa_options,
                        )
                        .map_err(EstimationError::InvalidInput)?
                        .objective())
                    },
                    |rho| {
                        state_from_beta_logisticspec(SasLinkSpec {
                            initial_epsilon: rho[0],
                            initial_log_delta: rho[1],
                        })
                        .ok()
                        .map(InverseLink::BetaLogistic)
                    },
                )
            }
            // Mixture: theta = rho with the component list held fixed.
            InverseLink::Mixture(state0) if !state0.rho.is_empty() => {
                let components = state0.components.clone();
                let components_recover = components.clone();
                let wiggle_cfg = wiggle.clone();
                optimize_link_parameters(
                    data,
                    spec,
                    kappa_options,
                    state0.rho.clone(),
                    "mixture",
                    wiggle.clone(),
                    move |rho: &Array1<f64>| {
                        let state = state_fromspec(&MixtureLinkSpec {
                            components: components.clone(),
                            initial_rho: rho.clone(),
                        })
                        .map_err(EstimationError::InvalidInput)?;
                        Ok(profile_survival_location_scale_with_inverse_link(
                            data,
                            spec,
                            InverseLink::Mixture(state),
                            wiggle_cfg.clone(),
                            kappa_options,
                        )
                        .map_err(EstimationError::InvalidInput)?
                        .objective())
                    },
                    move |rho| {
                        state_fromspec(&MixtureLinkSpec {
                            components: components_recover.clone(),
                            initial_rho: rho.to_owned(),
                        })
                        .ok()
                        .map(InverseLink::Mixture)
                    },
                )
            }
            // Links with nothing to tune (including mixtures with empty `rho`)
            // are fitted once.
            _ => profile_survival_location_scale(data, spec.clone(), wiggle, kappa_options),
        }
    }

    // Only run the outer link search when the caller asked for it.
    let profile = if request.optimize_inverse_link {
        optimize_survival_inverse_link_profile(
            request.data,
            &request.spec,
            request.wiggle.clone(),
            &request.kappa_options,
        )?
    } else {
        profile_survival_location_scale(
            request.data,
            request.spec.clone(),
            request.wiggle.clone(),
            &request.kappa_options,
        )?
    };

    Ok(profile.into_result())
}
1122
1123fn fit_bernoulli_marginal_slope_model(
1124 request: BernoulliMarginalSlopeFitRequest<'_>,
1125) -> Result<BernoulliMarginalSlopeFitResult, String> {
1126 let mut options = request.options.clone();
1132 crate::families::marginal_slope_shared::inject_biobank_outer_subsample_from_arrays(
1133 &mut options,
1134 request.spec.z.as_slice().expect("z is contiguous"),
1135 request.spec.y.as_slice().expect("y is contiguous"),
1136 );
1137 fit_bernoulli_marginal_slope_terms(
1138 request.data,
1139 request.spec,
1140 &options,
1141 &request.kappa_options,
1142 &request.policy,
1143 )
1144}
1145
1146fn fit_survival_marginal_slope_model(
1147 request: SurvivalMarginalSlopeFitRequest<'_>,
1148) -> Result<SurvivalMarginalSlopeFitResult, String> {
1149 let mut options = request.options.clone();
1153 crate::families::marginal_slope_shared::inject_biobank_outer_subsample_from_arrays(
1154 &mut options,
1155 request.spec.z.as_slice().expect("z is contiguous"),
1156 request
1157 .spec
1158 .event_target
1159 .as_slice()
1160 .expect("event_target is contiguous"),
1161 );
1162 fit_survival_marginal_slope_terms(request.data, request.spec, &options, &request.kappa_options)
1163}
1164
1165fn fit_latent_survival_model(
1166 request: LatentSurvivalFitRequest<'_>,
1167) -> Result<LatentSurvivalTermFitResult, String> {
1168 fit_latent_survival_terms(
1169 request.data,
1170 request.spec,
1171 request.frailty,
1172 &request.options,
1173 )
1174}
1175
1176fn fit_latent_binary_model(
1177 request: LatentBinaryFitRequest<'_>,
1178) -> Result<LatentBinaryTermFitResult, String> {
1179 fit_latent_binary_terms(
1180 request.data,
1181 request.spec,
1182 request.frailty,
1183 &request.options,
1184 )
1185}
1186
1187fn fit_transformation_normal_model(
1188 request: TransformationNormalFitRequest<'_>,
1189) -> Result<TransformationNormalFitResult, String> {
1190 fit_transformation_normal(
1191 &request.response,
1192 &request.weights,
1193 &request.offset,
1194 request.data,
1195 &request.covariate_spec,
1196 &request.config,
1197 &request.options,
1198 &request.kappa_options,
1199 request.warm_start.as_ref(),
1200 )
1201}
1202
1203pub fn fit_model(request: FitRequest<'_>) -> Result<FitResult, String> {
1204 match request {
1205 FitRequest::Standard(request) => fit_standard_model(request).map(FitResult::Standard),
1206 FitRequest::GaussianLocationScale(request) => {
1207 fit_gaussian_location_scale_model(request).map(FitResult::GaussianLocationScale)
1208 }
1209 FitRequest::BinomialLocationScale(request) => {
1210 fit_binomial_location_scale_model(request).map(FitResult::BinomialLocationScale)
1211 }
1212 FitRequest::SurvivalLocationScale(request) => {
1213 fit_survival_location_scale_model(request).map(FitResult::SurvivalLocationScale)
1214 }
1215 FitRequest::SurvivalTransformation(request) => {
1216 fit_survival_transformation_model(request).map(FitResult::SurvivalTransformation)
1217 }
1218 FitRequest::BernoulliMarginalSlope(request) => {
1219 fit_bernoulli_marginal_slope_model(request).map(FitResult::BernoulliMarginalSlope)
1220 }
1221 FitRequest::SurvivalMarginalSlope(request) => {
1222 fit_survival_marginal_slope_model(request).map(FitResult::SurvivalMarginalSlope)
1223 }
1224 FitRequest::LatentSurvival(request) => {
1225 fit_latent_survival_model(request).map(FitResult::LatentSurvival)
1226 }
1227 FitRequest::LatentBinary(request) => {
1228 fit_latent_binary_model(request).map(FitResult::LatentBinary)
1229 }
1230 FitRequest::TransformationNormal(request) => {
1231 fit_transformation_normal_model(request).map(FitResult::TransformationNormal)
1232 }
1233 }
1234}
1235
1236use crate::families::family_meta::{family_to_string, is_binomial_family};
1241use crate::families::survival_construction::{
1242 SurvivalBaselineTarget, SurvivalLikelihoodMode, SurvivalTimeBasisConfig,
1243 add_survival_time_derivative_guard_offset, append_zero_tail_columns,
1244 build_latent_survival_baseline_offsets, build_survival_time_basis,
1245 build_survival_time_offsets_for_likelihood, build_survival_timewiggle_from_baseline,
1246 build_time_varying_survival_covariate_template, center_survival_time_designs_at_anchor,
1247 evaluate_survival_time_basis_row, initial_survival_baseline_config_for_fit,
1248 marginal_slope_baseline_chain_rule_gradient, marginal_slope_baseline_chain_rule_hessian,
1249 normalize_survival_time_pair, optimize_survival_baseline_config,
1250 optimize_survival_baseline_config_with_gradient, parse_survival_distribution,
1251 parse_survival_likelihood_mode, parse_survival_time_basis_config, positive_survival_time_seed,
1252 require_structural_survival_time_basis, resolve_survival_time_anchor_value,
1253 resolved_survival_time_basis_config_from_build, survival_derivative_guard_for_likelihood,
1254};
1255use crate::families::survival_location_scale::{
1256 SurvivalCovariateTermBlockTemplate, TimeBlockInput, TimeWiggleBlockInput,
1257 residual_distribution_inverse_link,
1258};
1259use crate::inference::data::EncodedDataset as Dataset;
1260use crate::inference::formula_dsl::{
1261 LinkChoice, LinkWiggleFormulaSpec, ParsedFormula, ParsedTerm, effectivelinkwiggle_formulaspec,
1262 parse_formula, parse_link_choice, parse_matching_auxiliary_formula, parse_surv_response,
1263 require_inverse_link_supports_joint_wiggle, validate_marginal_slope_z_column_exclusion,
1264};
1265use crate::term_builder::{
1266 build_termspec, column_map_with_alias, enable_scale_dimensions, resolve_role_col,
1267};
1268
/// Configuration for formula-driven fitting via `fit_from_formula` / `materialize`.
///
/// `materialize` picks the model from the formula response and these fields:
/// a `Surv(...)` response selects survival; otherwise `transformation_normal`,
/// then `logslope_formula`/`z_column` (Bernoulli marginal-slope), then
/// `noise_formula` (location-scale), else a standard model.
#[derive(Clone, Debug)]
pub struct FitConfig {
    /// Explicit family name, matched case-insensitively by `resolve_family`;
    /// `None` infers the family from the link and/or the response.
    pub family: Option<String>,
    /// Link specification string, parsed by `parse_link_choice`.
    pub link: Option<String>,
    /// Passed to `parse_link_choice` together with `link`.
    pub flexible_link: bool,
    /// Column of finite offsets for the primary linear predictor.
    pub offset_column: Option<String>,
    /// Column of finite offsets for the auxiliary predictor (used as the
    /// log-slope offset in Bernoulli marginal-slope models). Rejected by
    /// standard models.
    pub noise_offset_column: Option<String>,
    /// Frailty specification. Standard models support a real frailty only for
    /// `latent-cloglog-binomial` (fixed-sigma hazard multiplier, full loading).
    pub frailty: Option<FrailtySpec>,

    // Survival baseline / time-basis settings. They are consumed by the
    // survival materialization path, most of which lies outside this section.
    pub baseline_target: String,
    pub baseline_scale: Option<f64>,
    pub baseline_shape: Option<f64>,
    pub baseline_rate: Option<f64>,
    pub baseline_makeham: Option<f64>,
    pub time_basis: String,
    pub time_degree: usize,
    pub time_num_internal_knots: usize,
    pub time_smooth_lambda: f64,
    /// Survival likelihood mode, parsed by `parse_survival_likelihood_mode`.
    pub survival_likelihood: String,
    pub survival_distribution: String,
    pub threshold_time_k: Option<usize>,
    pub threshold_time_degree: usize,
    pub sigma_time_k: Option<usize>,
    pub sigma_time_degree: usize,

    /// Auxiliary (noise) formula. Selects the location-scale model. It cannot
    /// be combined with `transformation_normal` or a Bernoulli marginal-slope
    /// model.
    pub noise_formula: Option<String>,

    /// Log-slope formula. Together with `z_column` it selects the Bernoulli
    /// marginal-slope model, which requires both.
    pub logslope_formula: Option<String>,
    /// Name of the `z` column for marginal-slope models.
    pub z_column: Option<String>,
    /// Column of finite, non-negative weights; `None` means unit weights.
    pub weight_column: Option<String>,

    /// When set, `enable_scale_dimensions` is applied to built term specs.
    pub scale_dimensions: bool,
    /// Ridge strength. NOTE(review): not consumed in this section — confirm
    /// where it is applied.
    pub ridge_lambda: f64,

    /// Fit a transformation-normal model. It cannot be combined with a
    /// `Surv(...)` response or with `noise_formula`.
    pub transformation_normal: bool,

    /// Enables Firth bias reduction (forwarded as `firth_bias_reduction` for
    /// standard models).
    pub firth: bool,

    /// Resource policy; `None` uses `ResourcePolicy::default_library()`.
    pub resource_policy: Option<crate::resource::ResourcePolicy>,
}
1337
1338impl Default for FitConfig {
1339 fn default() -> Self {
1340 Self {
1341 family: None,
1342 link: None,
1343 flexible_link: false,
1344 offset_column: None,
1345 noise_offset_column: None,
1346 frailty: None,
1347 baseline_target: "linear".into(),
1348 baseline_scale: None,
1349 baseline_shape: None,
1350 baseline_rate: None,
1351 baseline_makeham: None,
1352 time_basis: "ispline".into(),
1353 time_degree: 3,
1354 time_num_internal_knots: 8,
1355 time_smooth_lambda: 1e-2,
1356 survival_likelihood: "location-scale".into(),
1357 survival_distribution: "gaussian".into(),
1358 threshold_time_k: None,
1359 threshold_time_degree: 3,
1360 sigma_time_k: None,
1361 sigma_time_degree: 3,
1362 noise_formula: None,
1363 logslope_formula: None,
1364 z_column: None,
1365 weight_column: None,
1366 scale_dimensions: false,
1367 ridge_lambda: 1e-6,
1368 transformation_normal: false,
1369 firth: false,
1370 resource_policy: None,
1371 }
1372 }
1373}
1374
1375fn resolved_resource_policy(config: &FitConfig) -> crate::resource::ResourcePolicy {
1379 config
1380 .resource_policy
1381 .clone()
1382 .unwrap_or_else(crate::resource::ResourcePolicy::default_library)
1383}
1384
/// A fit request built from a formula and a `FitConfig`, ready for `fit_model`.
pub struct MaterializedModel<'a> {
    /// The request to pass to `fit_model`; it borrows the dataset's values for `'a`.
    pub request: FitRequest<'a>,
    /// Notes accumulated while building term specifications.
    pub inference_notes: Vec<String>,
}
1390
1391pub fn fit_from_formula(
1393 formula: &str,
1394 data: &Dataset,
1395 config: &FitConfig,
1396) -> Result<FitResult, String> {
1397 let mat = materialize(formula, data, config)?;
1398 fit_model(mat.request)
1399}
1400
1401pub fn materialize<'a>(
1403 formula: &str,
1404 data: &'a Dataset,
1405 config: &FitConfig,
1406) -> Result<MaterializedModel<'a>, String> {
1407 let parsed = parse_formula(formula)?;
1408 let col_map = data.column_map();
1409
1410 if let Some((entry_col, exit_col, event_col)) = parse_surv_response(&parsed.response)? {
1411 if config.transformation_normal {
1412 return Err(
1413 "transformation_normal cannot be combined with a Surv(...) response".to_string(),
1414 );
1415 }
1416 materialize_survival(
1417 &parsed, data, &col_map, config, &entry_col, &exit_col, &event_col,
1418 )
1419 } else if config.transformation_normal {
1420 if config.noise_formula.is_some() {
1421 return Err("transformation_normal cannot be combined with noise_formula".to_string());
1422 }
1423 materialize_transformation_normal(&parsed, data, &col_map, config)
1424 } else if config.logslope_formula.is_some() || config.z_column.is_some() {
1425 materialize_bernoulli_marginal_slope(&parsed, data, &col_map, config)
1426 } else if config.noise_formula.is_some() {
1427 materialize_location_scale(&parsed, data, &col_map, config)
1428 } else {
1429 materialize_standard(&parsed, data, &col_map, config)
1430 }
1431}
1432
1433pub fn is_binary_response(y: ArrayView1<'_, f64>) -> bool {
1435 if y.is_empty() {
1436 return false;
1437 }
1438 y.iter()
1439 .all(|v| (*v - 0.0).abs() < 1e-12 || (*v - 1.0).abs() < 1e-12)
1440}
1441
1442pub fn resolve_family(
1444 family: Option<&str>,
1445 link_choice: Option<&LinkChoice>,
1446 y: ArrayView1<'_, f64>,
1447) -> Result<LikelihoodFamily, String> {
1448 let explicit = family.and_then(|name| match name.to_ascii_lowercase().as_str() {
1449 "gaussian" => Some(LikelihoodFamily::GaussianIdentity),
1450 "binomial" | "binomial-logit" => Some(LikelihoodFamily::BinomialLogit),
1451 "binomial-probit" => Some(LikelihoodFamily::BinomialProbit),
1452 "binomial-cloglog" => Some(LikelihoodFamily::BinomialCLogLog),
1453 "latent-cloglog-binomial" => Some(LikelihoodFamily::BinomialLatentCLogLog),
1454 "poisson" => Some(LikelihoodFamily::PoissonLog),
1455 "gamma" => Some(LikelihoodFamily::GammaLog),
1456 _ => None,
1457 });
1458
1459 if let Some(choice) = link_choice {
1460 let from_link = if choice.mixture_components.is_some() {
1461 LikelihoodFamily::BinomialMixture
1462 } else {
1463 match choice.link {
1464 LinkFunction::Identity => LikelihoodFamily::GaussianIdentity,
1465 LinkFunction::Log => {
1466 if y.iter()
1467 .all(|&yi| yi.is_finite() && yi >= 0.0 && (yi - yi.round()).abs() <= 1e-9)
1468 {
1469 LikelihoodFamily::PoissonLog
1470 } else {
1471 LikelihoodFamily::GammaLog
1472 }
1473 }
1474 LinkFunction::Logit => LikelihoodFamily::BinomialLogit,
1475 LinkFunction::Probit => LikelihoodFamily::BinomialProbit,
1476 LinkFunction::CLogLog => LikelihoodFamily::BinomialCLogLog,
1477 LinkFunction::Sas => LikelihoodFamily::BinomialSas,
1478 LinkFunction::BetaLogistic => LikelihoodFamily::BinomialBetaLogistic,
1479 }
1480 };
1481 if let Some(explicit_family) = explicit {
1482 if explicit_family != from_link {
1483 return Err(format!(
1484 "family '{}' conflicts with link",
1485 family_to_string(explicit_family)
1486 ));
1487 }
1488 }
1489 return Ok(from_link);
1490 }
1491
1492 if let Some(f) = explicit {
1493 return Ok(f);
1494 }
1495
1496 if is_binary_response(y) {
1498 Ok(LikelihoodFamily::BinomialLogit)
1499 } else {
1500 Ok(LikelihoodFamily::GaussianIdentity)
1501 }
1502}
1503
1504fn build_termspec_with_geometry(
1509 terms: &[ParsedTerm],
1510 data: &Dataset,
1511 col_map: &HashMap<String, usize>,
1512 inference_notes: &mut Vec<String>,
1513 scale_dimensions: bool,
1514 policy: &crate::resource::ResourcePolicy,
1515) -> Result<TermCollectionSpec, String> {
1516 let mut spec = build_termspec(terms, data, col_map, inference_notes, policy)?;
1517 if scale_dimensions {
1518 enable_scale_dimensions(&mut spec);
1519 }
1520 Ok(spec)
1521}
1522
1523fn resolve_survival_marginal_slope_base_link(
1524 linkspec: Option<&crate::inference::formula_dsl::LinkFormulaSpec>,
1525) -> Result<InverseLink, String> {
1526 let Some(linkspec) = linkspec else {
1527 return Ok(InverseLink::Standard(LinkFunction::Probit));
1528 };
1529 let choice = parse_link_choice(Some(&linkspec.link), false)?
1530 .ok_or_else(|| "invalid survival marginal-slope link".to_string())?;
1531 if choice.mixture_components.is_some() {
1532 return Err(
1533 "survival marginal-slope currently supports only link(type=probit)".to_string(),
1534 );
1535 }
1536 match choice.link {
1537 LinkFunction::Probit => Ok(InverseLink::Standard(LinkFunction::Probit)),
1538 other => Err(format!(
1539 "survival marginal-slope currently supports only link(type=probit), got {other:?}"
1540 )),
1541 }
1542}
1543
/// Time-related offsets, designs and penalties assembled by
/// `prepare_workflow_survival_time_stack`.
struct PreparedWorkflowSurvivalTimeStack {
    /// Baseline offsets at entry and exit time and for the derivative at exit.
    /// They already include the derivative-guard adjustment.
    eta_offset_entry: Array1<f64>,
    eta_offset_exit: Array1<f64>,
    derivative_offset_exit: Array1<f64>,
    /// Unloaded baseline mass/hazard from latent-loading offsets; all zeros
    /// when no latent loading is used.
    unloaded_mass_entry: Array1<f64>,
    unloaded_mass_exit: Array1<f64>,
    unloaded_hazard_exit: Array1<f64>,
    /// Time-basis designs (entry, exit, derivative), extended via
    /// `append_zero_tail_columns` when a time wiggle is present.
    time_design_entry: crate::matrix::DesignMatrix,
    time_design_exit: crate::matrix::DesignMatrix,
    time_design_derivative: crate::matrix::DesignMatrix,
    /// Time penalties and their null-space dimensions. Each wiggle penalty is
    /// embedded in the trailing block and pushed together with its dimension.
    time_penalties: Vec<Array2<f64>>,
    time_nullspace_dims: Vec<usize>,
    /// Time-wiggle basis metadata, present only if a time wiggle was built.
    timewiggle_block: Option<TimeWiggleBlockInput>,
}
1558
/// Assembles baseline offsets, time-basis designs and time penalties for a
/// survival fit, optionally extended with a time wiggle.
///
/// Steps:
/// 1. Build baseline offsets.
///    - With `latent_loading`, use `build_latent_survival_baseline_offsets`.
///      That branch does not consult `likelihood_mode` or `inverse_link`.
///    - Otherwise use `build_survival_time_offsets_for_likelihood`, with the
///      unloaded terms zero-filled.
/// 2. Apply the derivative guard to the three offsets in place.
/// 3. If `effective_timewiggle` is given, build the wiggle from the *guarded*
///    offsets, extend the time designs, and embed its penalties.
///
/// # Errors
/// Propagates failures from the offset, guard and wiggle builders.
fn prepare_workflow_survival_time_stack(
    age_entry: &Array1<f64>,
    age_exit: &Array1<f64>,
    baseline_cfg: &crate::families::survival_construction::SurvivalBaselineConfig,
    likelihood_mode: SurvivalLikelihoodMode,
    inverse_link: Option<&InverseLink>,
    time_anchor: f64,
    derivative_guard: f64,
    time_build: &crate::families::survival_construction::SurvivalTimeBuildOutput,
    effective_timewiggle: Option<&LinkWiggleFormulaSpec>,
    latent_loading: Option<crate::families::lognormal_kernel::HazardLoading>,
) -> Result<PreparedWorkflowSurvivalTimeStack, String> {
    let (
        mut eta_offset_entry,
        mut eta_offset_exit,
        mut derivative_offset_exit,
        unloaded_mass_entry,
        unloaded_mass_exit,
        unloaded_hazard_exit,
    ) = if let Some(loading) = latent_loading {
        // Latent path: loaded offsets drive the predictor; unloaded mass and
        // hazard are carried separately.
        let offsets =
            build_latent_survival_baseline_offsets(age_entry, age_exit, baseline_cfg, loading)?;
        (
            offsets.loaded_eta_entry,
            offsets.loaded_eta_exit,
            offsets.loaded_derivative_exit,
            offsets.unloaded_mass_entry,
            offsets.unloaded_mass_exit,
            offsets.unloaded_hazard_exit,
        )
    } else {
        let (eta_offset_entry, eta_offset_exit, derivative_offset_exit) =
            build_survival_time_offsets_for_likelihood(
                age_entry,
                age_exit,
                baseline_cfg,
                likelihood_mode,
                inverse_link,
            )?;
        let n = age_entry.len();
        // No latent loading: the unloaded terms are identically zero.
        (
            eta_offset_entry,
            eta_offset_exit,
            derivative_offset_exit,
            Array1::zeros(n),
            Array1::zeros(n),
            Array1::zeros(n),
        )
    };
    // Must run before the wiggle build below, which reads these offsets.
    add_survival_time_derivative_guard_offset(
        age_entry,
        age_exit,
        time_anchor,
        derivative_guard,
        &mut eta_offset_entry,
        &mut eta_offset_exit,
        &mut derivative_offset_exit,
    )?;
    let timewiggle_build = if let Some(cfg) = effective_timewiggle {
        Some(build_survival_timewiggle_from_baseline(
            &eta_offset_entry,
            &eta_offset_exit,
            &derivative_offset_exit,
            cfg,
        )?)
    } else {
        None
    };
    let mut time_design_entry = time_build.x_entry_time.clone();
    let mut time_design_exit = time_build.x_exit_time.clone();
    let mut time_design_derivative = time_build.x_derivative_time.clone();
    let mut time_penalties = time_build.penalties.clone();
    let mut time_nullspace_dims = time_build.nullspace_dims.clone();
    let mut timewiggle_block = None;
    if let Some(wiggle) = timewiggle_build.as_ref() {
        // Base column count, read before the designs are widened; the wiggle
        // coefficients occupy columns `p_base..p_base + wiggle.ncols`.
        let p_base = time_design_exit.ncols();
        // Presumably pads each design with `wiggle.ncols` zero columns (per the
        // helper's name) — TODO confirm.
        append_zero_tail_columns(
            &mut time_design_entry,
            &mut time_design_exit,
            &mut time_design_derivative,
            wiggle.ncols,
        );
        for (idx, penalty) in wiggle.penalties.iter().enumerate() {
            // Embed the wiggle penalty in the trailing diagonal block of a
            // full-width square penalty matrix.
            let mut embedded = Array2::<f64>::zeros((p_base + wiggle.ncols, p_base + wiggle.ncols));
            embedded
                .slice_mut(s![
                    p_base..p_base + wiggle.ncols,
                    p_base..p_base + wiggle.ncols
                ])
                .assign(penalty);
            time_penalties.push(embedded);
            // A missing null-space dimension defaults to 0.
            time_nullspace_dims.push(wiggle.nullspace_dims.get(idx).copied().unwrap_or(0));
        }
        timewiggle_block = Some(TimeWiggleBlockInput {
            knots: wiggle.knots.clone(),
            degree: wiggle.degree,
            ncols: wiggle.ncols,
        });
    }
    Ok(PreparedWorkflowSurvivalTimeStack {
        eta_offset_entry,
        eta_offset_exit,
        derivative_offset_exit,
        unloaded_mass_entry,
        unloaded_mass_exit,
        unloaded_hazard_exit,
        time_design_entry,
        time_design_exit,
        time_design_derivative,
        time_penalties,
        time_nullspace_dims,
        timewiggle_block,
    })
}
1673
1674fn resolve_continuous_column(
1675 data: &Dataset,
1676 col_map: &HashMap<String, usize>,
1677 column_name: &str,
1678 role: &str,
1679) -> Result<Array1<f64>, String> {
1680 let col_idx = resolve_role_col(col_map, column_name, role)?;
1681 let values = data.values.column(col_idx).to_owned();
1682 for (row_idx, value) in values.iter().enumerate() {
1683 if !value.is_finite() {
1684 return Err(format!(
1685 "{role} column '{column_name}' contains non-finite value at row {row_idx}: {value}"
1686 ));
1687 }
1688 }
1689 Ok(values)
1690}
1691
1692pub fn resolve_offset_column(
1693 data: &Dataset,
1694 col_map: &HashMap<String, usize>,
1695 column_name: Option<&str>,
1696) -> Result<Array1<f64>, String> {
1697 let Some(column_name) = column_name else {
1698 return Ok(Array1::zeros(data.values.nrows()));
1699 };
1700 resolve_continuous_column(data, col_map, column_name, "offset")
1701}
1702
1703pub fn resolve_weight_column(
1704 data: &Dataset,
1705 col_map: &HashMap<String, usize>,
1706 column_name: Option<&str>,
1707) -> Result<Array1<f64>, String> {
1708 let Some(column_name) = column_name else {
1709 return Ok(Array1::ones(data.values.nrows()));
1710 };
1711 let values = resolve_continuous_column(data, col_map, column_name, "weights")?;
1712 for (row_idx, value) in values.iter().enumerate() {
1713 if *value < 0.0 {
1714 return Err(format!(
1715 "weights column '{column_name}' must be non-negative; found {value} at row {row_idx}"
1716 ));
1717 }
1718 }
1719 Ok(values)
1720}
1721
1722fn materialize_standard<'a>(
1723 parsed: &ParsedFormula,
1724 data: &'a Dataset,
1725 col_map: &HashMap<String, usize>,
1726 config: &FitConfig,
1727) -> Result<MaterializedModel<'a>, String> {
1728 if config.noise_offset_column.is_some() {
1729 return Err(
1730 "noise_offset_column requires a location-scale model with noise_formula".to_string(),
1731 );
1732 }
1733 let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
1734 let y = data.values.column(y_col).to_owned();
1735 let mut inference_notes = Vec::new();
1736
1737 let link_choice = parse_link_choice(config.link.as_deref(), config.flexible_link)?;
1738 let family = resolve_family(config.family.as_deref(), link_choice.as_ref(), y.view())?;
1739
1740 let effective_linkwiggle =
1741 effectivelinkwiggle_formulaspec(parsed.linkwiggle.as_ref(), link_choice.as_ref());
1742
1743 let policy = resolved_resource_policy(config);
1744 let spec = build_termspec_with_geometry(
1745 &parsed.terms,
1746 data,
1747 col_map,
1748 &mut inference_notes,
1749 config.scale_dimensions,
1750 &policy,
1751 )?;
1752
1753 let weights = resolve_weight_column(data, col_map, config.weight_column.as_deref())?;
1754 let offset = resolve_offset_column(data, col_map, config.offset_column.as_deref())?;
1755 let latent_cloglog = if matches!(family, LikelihoodFamily::BinomialLatentCLogLog) {
1756 let sigma = match config.frailty.clone().unwrap_or(FrailtySpec::None) {
1757 FrailtySpec::HazardMultiplier {
1758 sigma_fixed: Some(sigma),
1759 loading: crate::families::lognormal_kernel::HazardLoading::Full,
1760 } => sigma,
1761 FrailtySpec::HazardMultiplier {
1762 sigma_fixed: Some(_),
1763 loading,
1764 } => {
1765 return Err(format!(
1766 "latent-cloglog-binomial requires HazardLoading::Full, got {loading:?}"
1767 ));
1768 }
1769 FrailtySpec::HazardMultiplier {
1770 sigma_fixed: None, ..
1771 } => {
1772 return Err(
1773 "latent-cloglog-binomial currently requires a fixed hazard-multiplier sigma"
1774 .to_string(),
1775 );
1776 }
1777 FrailtySpec::GaussianShift { .. } => {
1778 return Err(
1779 "latent-cloglog-binomial does not support GaussianShift frailty".to_string(),
1780 );
1781 }
1782 FrailtySpec::None => {
1783 return Err(
1784 "latent-cloglog-binomial requires config.frailty=HazardMultiplier with a fixed sigma"
1785 .to_string(),
1786 );
1787 }
1788 };
1789 Some(
1790 LatentCLogLogState::new(sigma)
1791 .map_err(|e| format!("invalid latent_cloglog state: {e}"))?,
1792 )
1793 } else {
1794 if config.frailty.is_some() {
1795 return Err(format!(
1796 "config.frailty is not supported for standard family {:?}; use a frailty-aware family instead",
1797 family
1798 ));
1799 }
1800 None
1801 };
1802 let options = FitOptions {
1803 latent_cloglog,
1804 mixture_link: None,
1805 optimize_mixture: false,
1806 sas_link: None,
1807 optimize_sas: false,
1808 compute_inference: true,
1809 max_iter: 200,
1810 tol: 1e-7,
1811 nullspace_dims: vec![],
1812 linear_constraints: None,
1813 firth_bias_reduction: config.firth,
1814 adaptive_regularization: None,
1815 penalty_shrinkage_floor: Some(1e-6),
1816 rho_prior: Default::default(),
1817 kronecker_penalty_system: None,
1818 kronecker_factored: None,
1819 };
1820 let kappa_options = SpatialLengthScaleOptimizationOptions::default();
1821
1822 let wiggle = effective_linkwiggle.as_ref().and_then(|cfg| {
1823 if !is_binomial_family(family) {
1824 return None;
1825 }
1826 let link_kind = link_choice
1827 .as_ref()
1828 .map(|c| InverseLink::Standard(c.link))
1829 .unwrap_or_else(|| {
1830 if let Some(state) = latent_cloglog {
1831 InverseLink::LatentCLogLog(state)
1832 } else {
1833 InverseLink::Standard(LinkFunction::Logit)
1834 }
1835 });
1836 Some(StandardBinomialWiggleConfig {
1837 link_kind,
1838 wiggle: LinkWiggleConfig {
1839 degree: cfg.degree,
1840 num_internal_knots: cfg.num_internal_knots,
1841 penalty_orders: cfg.penalty_orders.clone(),
1842 double_penalty: cfg.double_penalty,
1843 },
1844 })
1845 });
1846
1847 Ok(MaterializedModel {
1848 request: FitRequest::Standard(StandardFitRequest {
1849 data: data.values.view(),
1850 y,
1851 weights,
1852 offset,
1853 spec,
1854 family,
1855 options,
1856 kappa_options,
1857 wiggle,
1858 wiggle_options: None,
1859 }),
1860 inference_notes,
1861 })
1862}
1863
/// Materializes a Bernoulli marginal-slope fit request.
///
/// Requirements:
/// - a binary {0,1} response;
/// - no `noise_formula`;
/// - both `logslope_formula` and `z_column`;
/// - no `link(...)` inside `logslope_formula`.
///
/// The `z` column is validated against both formulas by
/// `validate_marginal_slope_z_column_exclusion`. The base link is fixed to
/// probit. `config.offset_column` feeds the marginal predictor and
/// `config.noise_offset_column` the log-slope predictor.
///
/// # Errors
/// Returns the first failed requirement above (in that order), or any
/// parsing, term-building or column-resolution error.
fn materialize_bernoulli_marginal_slope<'a>(
    parsed: &ParsedFormula,
    data: &'a Dataset,
    col_map: &HashMap<String, usize>,
    config: &FitConfig,
) -> Result<MaterializedModel<'a>, String> {
    let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
    let y = data.values.column(y_col).to_owned();

    if !is_binary_response(y.view()) {
        return Err("Bernoulli marginal-slope requires a binary {0,1} response".to_string());
    }
    if config.noise_formula.is_some() {
        return Err("Bernoulli marginal-slope cannot also use noise_formula".to_string());
    }

    let logslope_formula = config
        .logslope_formula
        .as_deref()
        .ok_or_else(|| "Bernoulli marginal-slope requires logslope_formula".to_string())?;
    let z_column = config
        .z_column
        .as_deref()
        .ok_or_else(|| "Bernoulli marginal-slope requires z_column".to_string())?;

    // The log-slope formula must share the main formula's response.
    let (_, parsed_logslope) =
        parse_matching_auxiliary_formula(logslope_formula, &parsed.response, "logslope_formula")?;
    // NOTE(review): only `logslope_formula` is checked for `link(...)`; a link
    // in the main formula is not rejected here — confirm that is intended.
    if parsed_logslope.linkspec.is_some() {
        return Err("link(...) is not supported inside logslope_formula".to_string());
    }
    validate_marginal_slope_z_column_exclusion(
        parsed,
        &parsed_logslope,
        z_column,
        "Bernoulli marginal-slope",
        "logslope_formula",
    )?;

    let mut inference_notes = Vec::new();
    let policy = resolved_resource_policy(config);
    // Term building sees the z column additionally under the alias "z".
    let aliased_col_map = column_map_with_alias(col_map, "z", z_column);
    let marginalspec = build_termspec_with_geometry(
        &parsed.terms,
        data,
        &aliased_col_map,
        &mut inference_notes,
        config.scale_dimensions,
        &policy,
    )?;
    let logslopespec = build_termspec_with_geometry(
        &parsed_logslope.terms,
        data,
        &aliased_col_map,
        &mut inference_notes,
        config.scale_dimensions,
        &policy,
    )?;
    // NOTE(review): unlike offsets and weights, `z` is read without a
    // finiteness check here — confirm downstream validation.
    let z_idx = resolve_role_col(col_map, z_column, "z")?;
    let z = data.values.column(z_idx).to_owned();
    let weights = resolve_weight_column(data, col_map, config.weight_column.as_deref())?;
    let marginal_offset = resolve_offset_column(data, col_map, config.offset_column.as_deref())?;
    // `noise_offset_column` doubles as the log-slope offset in this model.
    let logslope_offset =
        resolve_offset_column(data, col_map, config.noise_offset_column.as_deref())?;
    // Route the formulas' linkwiggle(...) specs to the score-warp / link
    // deviation blocks.
    let routing = route_marginal_slope_deviation_blocks(
        parsed.linkwiggle.as_ref(),
        parsed_logslope.linkwiggle.as_ref(),
    )?;
    let spec = BernoulliMarginalSlopeTermSpec {
        y,
        weights,
        z,
        base_link: InverseLink::Standard(LinkFunction::Probit),
        marginalspec,
        logslopespec,
        marginal_offset,
        logslope_offset,
        frailty: config.frailty.clone().unwrap_or(FrailtySpec::None),
        score_warp: routing.score_warp,
        link_dev: routing.link_dev,
        latent_z_policy: Default::default(),
    };

    Ok(MaterializedModel {
        request: FitRequest::BernoulliMarginalSlope(BernoulliMarginalSlopeFitRequest {
            data: data.values.view(),
            spec,
            options: BlockwiseFitOptions {
                compute_covariance: true,
                ..Default::default()
            },
            kappa_options: SpatialLengthScaleOptimizationOptions::default(),
            policy,
        }),
        inference_notes,
    })
}
1960
1961fn materialize_survival<'a>(
1962 parsed: &ParsedFormula,
1963 data: &'a Dataset,
1964 col_map: &HashMap<String, usize>,
1965 config: &FitConfig,
1966 entry_col: &str,
1967 exit_col: &str,
1968 event_col: &str,
1969) -> Result<MaterializedModel<'a>, String> {
1970 let mut inference_notes = Vec::new();
1971
1972 let entry_idx = resolve_role_col(col_map, entry_col, "entry")?;
1974 let exit_idx = resolve_role_col(col_map, exit_col, "exit")?;
1975 let event_idx = resolve_role_col(col_map, event_col, "event")?;
1976 use rayon::iter::{IntoParallelIterator, ParallelIterator};
1977 let n = data.values.nrows();
1978 let event = data.values.column(event_idx).to_owned();
1979 let pairs: Result<Vec<(f64, f64)>, String> = (0..n)
1980 .into_par_iter()
1981 .map(|i| {
1982 normalize_survival_time_pair(data.values[[i, entry_idx]], data.values[[i, exit_idx]], i)
1983 })
1984 .collect();
1985 let pairs = pairs?;
1986 let mut age_entry = Array1::<f64>::zeros(n);
1987 let mut age_exit = Array1::<f64>::zeros(n);
1988 for (i, (e, x)) in pairs.into_iter().enumerate() {
1989 age_entry[i] = e;
1990 age_exit[i] = x;
1991 }
1992
1993 let survival_mode = parse_survival_likelihood_mode(&config.survival_likelihood)?;
1994 if parsed.linkwiggle.is_some()
1995 && !matches!(
1996 survival_mode,
1997 SurvivalLikelihoodMode::LocationScale | SurvivalLikelihoodMode::MarginalSlope
1998 )
1999 {
2000 return Err(format!(
2001 "linkwiggle(...) is not defined for survival_likelihood='{}'",
2002 config.survival_likelihood
2003 ));
2004 }
2005 if parsed.linkspec.is_some()
2006 && matches!(
2007 survival_mode,
2008 SurvivalLikelihoodMode::Transformation
2009 | SurvivalLikelihoodMode::Weibull
2010 | SurvivalLikelihoodMode::Latent
2011 | SurvivalLikelihoodMode::LatentBinary
2012 )
2013 {
2014 return Err(format!(
2015 "link(...) is not implemented for survival_likelihood='{}'",
2016 config.survival_likelihood
2017 ));
2018 }
2019 let effective_timewiggle = parsed.timewiggle.clone();
2020 let baseline_target_raw = match survival_mode {
2021 SurvivalLikelihoodMode::Weibull if effective_timewiggle.is_some() => "weibull",
2022 SurvivalLikelihoodMode::Weibull => "linear",
2023 _ => &config.baseline_target,
2024 };
2025 let baseline_cfg = initial_survival_baseline_config_for_fit(
2026 baseline_target_raw,
2027 config.baseline_scale,
2028 config.baseline_shape,
2029 config.baseline_rate,
2030 config.baseline_makeham,
2031 &age_exit,
2032 )?;
2033 if matches!(
2034 survival_mode,
2035 SurvivalLikelihoodMode::Latent | SurvivalLikelihoodMode::LatentBinary
2036 ) && baseline_cfg.target == SurvivalBaselineTarget::Linear
2037 {
2038 return Err(
2039 "latent hazard-window families require a non-linear scalar baseline target; use baseline_target weibull, gompertz, or gompertz-makeham"
2040 .to_string(),
2041 );
2042 }
2043 let time_cfg = if effective_timewiggle.is_some() {
2044 SurvivalTimeBasisConfig::None
2047 } else if survival_mode == SurvivalLikelihoodMode::Weibull {
2048 SurvivalTimeBasisConfig::Linear
2049 } else {
2050 parse_survival_time_basis_config(
2051 &config.time_basis,
2052 config.time_degree,
2053 config.time_num_internal_knots,
2054 config.time_smooth_lambda,
2055 )?
2056 };
2057 let time_anchor = resolve_survival_time_anchor_value(&age_entry, None)?;
2058 let exact_derivative_guard = survival_derivative_guard_for_likelihood(survival_mode);
2059
2060 let mut time_build = build_survival_time_basis(
2062 &age_entry,
2063 &age_exit,
2064 time_cfg.clone(),
2065 Some((config.time_num_internal_knots, config.time_smooth_lambda)),
2066 )?;
2067 if survival_mode != SurvivalLikelihoodMode::Weibull && effective_timewiggle.is_none() {
2068 require_structural_survival_time_basis(&time_build.basisname, "workflow survival fitting")?;
2069 }
2070 let resolved_time_cfg = resolved_survival_time_basis_config_from_build(
2071 &time_build.basisname,
2072 time_build.degree,
2073 time_build.knots.as_ref(),
2074 time_build.keep_cols.as_ref(),
2075 time_build.smooth_lambda,
2076 )?;
2077 let time_anchor_row = evaluate_survival_time_basis_row(time_anchor, &resolved_time_cfg)?;
2078 center_survival_time_designs_at_anchor(
2079 &mut time_build.x_entry_time,
2080 &mut time_build.x_exit_time,
2081 &time_anchor_row,
2082 )?;
2083 if effective_timewiggle.is_some() && baseline_cfg.target == SurvivalBaselineTarget::Linear {
2084 return Err(
2085 "timewiggle requires a non-linear scalar survival baseline target; \
2086 use baseline_target weibull, gompertz, or gompertz-makeham"
2087 .to_string(),
2088 );
2089 }
2090
2091 let policy = resolved_resource_policy(config);
2092 let marginal_slope_aliased_col_map = if survival_mode == SurvivalLikelihoodMode::MarginalSlope {
2093 Some(column_map_with_alias(
2094 col_map,
2095 "z",
2096 config.z_column.as_deref().ok_or_else(|| {
2097 "marginal-slope survival requires z_column in FitConfig".to_string()
2098 })?,
2099 ))
2100 } else {
2101 None
2102 };
2103 let termspec_col_map = marginal_slope_aliased_col_map.as_ref().unwrap_or(col_map);
2104 let termspec = build_termspec_with_geometry(
2105 &parsed.terms,
2106 data,
2107 termspec_col_map,
2108 &mut inference_notes,
2109 config.scale_dimensions,
2110 &policy,
2111 )?;
2112
2113 let residual_dist = parse_survival_distribution(&config.survival_distribution)?;
2114 let survival_inverse_link = residual_distribution_inverse_link(residual_dist);
2115 let link_choice = parse_link_choice(config.link.as_deref(), config.flexible_link)?;
2116 let effective_linkwiggle =
2117 effectivelinkwiggle_formulaspec(parsed.linkwiggle.as_ref(), link_choice.as_ref());
2118 let effective_linkwiggle_cfg = effective_linkwiggle.clone().map(|cfg| LinkWiggleConfig {
2119 degree: cfg.degree,
2120 num_internal_knots: cfg.num_internal_knots,
2121 penalty_orders: cfg.penalty_orders,
2122 double_penalty: cfg.double_penalty,
2123 });
2124
2125 let weights = resolve_weight_column(data, col_map, config.weight_column.as_deref())?;
2126 let threshold_offset = resolve_offset_column(data, col_map, config.offset_column.as_deref())?;
2127 let log_sigma_offset =
2128 resolve_offset_column(data, col_map, config.noise_offset_column.as_deref())?;
2129 let threshold_template = if let Some(k) = config.threshold_time_k {
2130 build_time_varying_survival_covariate_template(
2131 &age_entry,
2132 &age_exit,
2133 k,
2134 config.threshold_time_degree,
2135 "threshold",
2136 )?
2137 } else {
2138 SurvivalCovariateTermBlockTemplate::Static
2139 };
2140 let log_sigma_template = if let Some(k) = config.sigma_time_k {
2141 build_time_varying_survival_covariate_template(
2142 &age_entry,
2143 &age_exit,
2144 k,
2145 config.sigma_time_degree,
2146 "sigma",
2147 )?
2148 } else {
2149 SurvivalCovariateTermBlockTemplate::Static
2150 };
2151 let log_sigmaspec = if let Some(noise) = config.noise_formula.as_deref() {
2152 let noise_parsed = parse_formula(&format!("{} ~ {noise}", parsed.response))?;
2153 build_termspec_with_geometry(
2157 &noise_parsed.terms,
2158 data,
2159 termspec_col_map,
2160 &mut inference_notes,
2161 config.scale_dimensions,
2162 &policy,
2163 )?
2164 } else if survival_mode == SurvivalLikelihoodMode::LocationScale {
2165 termspec.clone()
2166 } else {
2167 TermCollectionSpec {
2168 linear_terms: vec![],
2169 random_effect_terms: vec![],
2170 smooth_terms: vec![],
2171 }
2172 };
2173 let marginal_z_column_name =
2174 if survival_mode == SurvivalLikelihoodMode::MarginalSlope {
2175 Some(config.z_column.as_deref().ok_or_else(|| {
2176 "marginal-slope survival requires z_column in FitConfig".to_string()
2177 })?)
2178 } else {
2179 None
2180 };
2181 let marginal_z = if survival_mode == SurvivalLikelihoodMode::MarginalSlope {
2182 let _base_link = resolve_survival_marginal_slope_base_link(parsed.linkspec.as_ref())?;
2183 let z_col_name = marginal_z_column_name
2184 .expect("marginal-slope z column should be validated before materialization");
2185 let z_idx = resolve_role_col(col_map, z_col_name, "z")?;
2186 Some(data.values.column(z_idx).to_owned())
2187 } else {
2188 None
2189 };
2190 let (marginal_logslopespec, marginal_slope_deviation_routing) = if survival_mode
2191 == SurvivalLikelihoodMode::MarginalSlope
2192 {
2193 if let Some(ls_formula) = config.logslope_formula.as_deref() {
2194 let (_, ls_parsed) =
2195 parse_matching_auxiliary_formula(ls_formula, &parsed.response, "logslope_formula")?;
2196 if ls_parsed.linkspec.is_some() {
2197 return Err(
2198 "link(...) is not supported in logslope_formula for the survival marginal-slope family"
2199 .to_string(),
2200 );
2201 }
2202 if ls_parsed.timewiggle.is_some() {
2203 return Err(
2204 "timewiggle(...) is not supported in logslope_formula for the survival marginal-slope family"
2205 .to_string(),
2206 );
2207 }
2208 if ls_parsed.survivalspec.is_some() {
2209 return Err(
2210 "survmodel(...) is not supported in logslope_formula for the survival marginal-slope family"
2211 .to_string(),
2212 );
2213 }
2214 validate_marginal_slope_z_column_exclusion(
2215 parsed,
2216 &ls_parsed,
2217 marginal_z_column_name.expect("marginal-slope z column should be available"),
2218 "survival marginal-slope",
2219 "logslope_formula",
2220 )?;
2221 (
2222 Some(build_termspec_with_geometry(
2223 &ls_parsed.terms,
2224 data,
2225 marginal_slope_aliased_col_map
2226 .as_ref()
2227 .expect("marginal-slope column map should be available"),
2228 &mut inference_notes,
2229 config.scale_dimensions,
2230 &policy,
2231 )?),
2232 route_marginal_slope_deviation_blocks(
2233 parsed.linkwiggle.as_ref(),
2234 ls_parsed.linkwiggle.as_ref(),
2235 )?,
2236 )
2237 } else {
2238 validate_marginal_slope_z_column_exclusion(
2239 parsed,
2240 parsed,
2241 marginal_z_column_name.expect("marginal-slope z column should be available"),
2242 "survival marginal-slope",
2243 "logslope_formula",
2244 )?;
2245 (
2246 Some(termspec.clone()),
2247 route_marginal_slope_deviation_blocks(parsed.linkwiggle.as_ref(), None)?,
2248 )
2249 }
2250 } else {
2251 (
2252 None,
2253 MarginalSlopeDeviationRouting {
2254 score_warp: None,
2255 link_dev: None,
2256 },
2257 )
2258 };
2259 let marginal_slope_score_warp = marginal_slope_deviation_routing.score_warp;
2260 let marginal_slope_link_dev = marginal_slope_deviation_routing.link_dev;
2261 if survival_mode == SurvivalLikelihoodMode::MarginalSlope {
2262 if parsed.linkwiggle.is_some() {
2263 inference_notes.push(
2264 "survival marginal-slope routes formula-level linkwiggle(...) into its anchored internal link-deviation block while keeping the probit survival base link".to_string(),
2265 );
2266 }
2267 if marginal_slope_score_warp.is_some() {
2268 inference_notes.push(
2269 "survival marginal-slope routes logslope_formula linkwiggle(...) into its anchored internal score-warp block while keeping the probit survival base link".to_string(),
2270 );
2271 }
2272 if marginal_slope_link_dev.is_none() && marginal_slope_score_warp.is_none() {
2273 inference_notes.push(
2274 "survival marginal-slope rigid mode is algebraic closed-form exact".to_string(),
2275 );
2276 } else {
2277 inference_notes.push(
2278 "survival marginal-slope flexible score/link mode uses calibrated de-nested cubic transport cells with analytic value evaluation and calibrated survival normalization"
2279 .to_string(),
2280 );
2281 }
2282 }
2283 let marginal_slope_frailty = if survival_mode == SurvivalLikelihoodMode::MarginalSlope {
2284 Some(fixed_gaussian_shift_frailty_from_spec(
2285 config.frailty.as_ref().unwrap_or(&FrailtySpec::None),
2286 "survival marginal-slope",
2287 )?)
2288 } else {
2289 None
2290 };
2291 match survival_mode {
2292 SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull
2293 if config.frailty.is_some() =>
2294 {
2295 return Err(
2296 "frailty is not supported for transformation/weibull survival models".to_string(),
2297 );
2298 }
2299 SurvivalLikelihoodMode::LocationScale if config.frailty.is_some() => {
2300 return Err(
2301 "config.frailty is not implemented for survival-likelihood=location-scale"
2302 .to_string(),
2303 );
2304 }
2305 SurvivalLikelihoodMode::Latent | SurvivalLikelihoodMode::LatentBinary
2306 if effective_timewiggle.is_some() =>
2307 {
2308 return Err(
2309 "timewiggle is not implemented for latent survival/binary likelihoods".to_string(),
2310 );
2311 }
2312 _ => {}
2313 }
2314 let latent_loading = if matches!(
2315 survival_mode,
2316 SurvivalLikelihoodMode::Latent | SurvivalLikelihoodMode::LatentBinary
2317 ) {
2318 let frailty = config.frailty.as_ref().unwrap_or(&FrailtySpec::None);
2319 Some(latent_hazard_loading(
2320 frailty,
2321 "workflow latent survival/binary",
2322 )?)
2323 } else {
2324 None
2325 };
2326
2327 let build_time_block =
2328 |candidate: &crate::families::survival_construction::SurvivalBaselineConfig| {
2329 let prepared = prepare_workflow_survival_time_stack(
2330 &age_entry,
2331 &age_exit,
2332 candidate,
2333 survival_mode,
2334 (survival_mode == SurvivalLikelihoodMode::LocationScale)
2335 .then_some(&survival_inverse_link),
2336 time_anchor,
2337 exact_derivative_guard,
2338 &time_build,
2339 effective_timewiggle.as_ref(),
2340 None,
2341 )?;
2342 let time_p = prepared.time_design_exit.ncols();
2343 let time_initial_log_lambdas = if prepared.time_penalties.is_empty() {
2344 None
2345 } else {
2346 Some(Array1::from_elem(
2347 prepared.time_penalties.len(),
2348 config.time_smooth_lambda.ln(),
2349 ))
2350 };
2351 let time_block = TimeBlockInput {
2352 design_entry: prepared.time_design_entry.clone(),
2353 design_exit: prepared.time_design_exit.clone(),
2354 design_derivative_exit: prepared.time_design_derivative.clone(),
2355 offset_entry: prepared.eta_offset_entry.clone(),
2356 offset_exit: prepared.eta_offset_exit.clone(),
2357 derivative_offset_exit: prepared.derivative_offset_exit.clone(),
2358 structural_monotonicity: true,
2359 penalties: prepared.time_penalties.clone(),
2360 nullspace_dims: prepared.time_nullspace_dims.clone(),
2361 initial_log_lambdas: time_initial_log_lambdas,
2362 initial_beta: Some(Array1::from_elem(time_p, 1e-4)),
2363 };
2364 Ok::<_, String>((prepared, time_block))
2365 };
2366
2367 let build_location_scale_request =
2368 |candidate: &crate::families::survival_construction::SurvivalBaselineConfig| {
2369 let (prepared, time_block) = build_time_block(candidate)?;
2370 let spec = SurvivalLocationScaleTermSpec {
2371 age_entry: age_entry.clone(),
2372 age_exit: age_exit.clone(),
2373 event_target: event.clone(),
2374 weights: weights.clone(),
2375 inverse_link: survival_inverse_link.clone(),
2376 derivative_guard: exact_derivative_guard,
2377 max_iter: 200,
2378 tol: 1e-7,
2379 time_block,
2380 thresholdspec: termspec.clone(),
2381 log_sigmaspec: log_sigmaspec.clone(),
2382 threshold_offset: threshold_offset.clone(),
2383 log_sigma_offset: log_sigma_offset.clone(),
2384 threshold_template: threshold_template.clone(),
2385 log_sigma_template: log_sigma_template.clone(),
2386 timewiggle_block: prepared.timewiggle_block,
2387 linkwiggle_block: None,
2388 };
2389 let optimize_inverse_link =
2390 survival_inverse_link_has_free_parameters(&spec.inverse_link);
2391 Ok::<_, String>(SurvivalLocationScaleFitRequest {
2392 data: data.values.view(),
2393 spec,
2394 wiggle: effective_linkwiggle_cfg.clone(),
2395 kappa_options: SpatialLengthScaleOptimizationOptions::default(),
2396 optimize_inverse_link,
2397 })
2398 };
2399
2400 let build_marginal_slope_request =
2401 |candidate: &crate::families::survival_construction::SurvivalBaselineConfig| {
2402 let (prepared, time_block) = build_time_block(candidate)?;
2403 Ok::<_, String>(SurvivalMarginalSlopeFitRequest {
2404 data: data.values.view(),
2405 spec: SurvivalMarginalSlopeTermSpec {
2406 age_entry: age_entry.clone(),
2407 age_exit: age_exit.clone(),
2408 event_target: event.clone(),
2409 weights: weights.clone(),
2410 z: marginal_z.clone().ok_or_else(|| {
2411 "marginal-slope survival requires z_column in FitConfig".to_string()
2412 })?,
2413 base_link: resolve_survival_marginal_slope_base_link(parsed.linkspec.as_ref())?,
2414 marginalspec: termspec.clone(),
2415 marginal_offset: threshold_offset.clone(),
2416 frailty: marginal_slope_frailty.clone().ok_or_else(|| {
2417 "internal error: marginal-slope frailty validation missing".to_string()
2418 })?,
2419 derivative_guard: exact_derivative_guard,
2420 time_block,
2421 timewiggle_block: prepared.timewiggle_block,
2422 logslopespec: marginal_logslopespec.clone().ok_or_else(|| {
2423 "marginal-slope survival is missing logslope spec".to_string()
2424 })?,
2425 logslope_offset: log_sigma_offset.clone(),
2426 score_warp: marginal_slope_score_warp.clone(),
2427 link_dev: marginal_slope_link_dev.clone(),
2428 latent_z_policy: Default::default(),
2429 },
2430 options: BlockwiseFitOptions {
2431 compute_covariance: false,
2432 ..Default::default()
2433 },
2434 kappa_options: SpatialLengthScaleOptimizationOptions::default(),
2435 })
2436 };
2437
2438 let build_latent_survival_request =
2439 |candidate: &crate::families::survival_construction::SurvivalBaselineConfig| {
2440 let loading = latent_loading.ok_or_else(|| {
2441 "internal error: latent survival loading missing after frailty validation"
2442 .to_string()
2443 })?;
2444 let prepared = prepare_workflow_survival_time_stack(
2445 &age_entry,
2446 &age_exit,
2447 candidate,
2448 survival_mode,
2449 None,
2450 time_anchor,
2451 exact_derivative_guard,
2452 &time_build,
2453 None,
2454 Some(loading),
2455 )?;
2456 let time_p = prepared.time_design_exit.ncols();
2457 let time_initial_log_lambdas = if prepared.time_penalties.is_empty() {
2458 None
2459 } else {
2460 Some(Array1::from_elem(
2461 prepared.time_penalties.len(),
2462 config.time_smooth_lambda.ln(),
2463 ))
2464 };
2465 let time_block = TimeBlockInput {
2466 design_entry: prepared.time_design_entry.clone(),
2467 design_exit: prepared.time_design_exit.clone(),
2468 design_derivative_exit: prepared.time_design_derivative.clone(),
2469 offset_entry: prepared.eta_offset_entry.clone(),
2470 offset_exit: prepared.eta_offset_exit.clone(),
2471 derivative_offset_exit: prepared.derivative_offset_exit.clone(),
2472 structural_monotonicity: true,
2473 penalties: prepared.time_penalties.clone(),
2474 nullspace_dims: prepared.time_nullspace_dims.clone(),
2475 initial_log_lambdas: time_initial_log_lambdas,
2476 initial_beta: Some(Array1::from_elem(time_p, 1e-4)),
2477 };
2478 Ok::<_, String>(LatentSurvivalFitRequest {
2479 data: data.values.view(),
2480 spec: LatentSurvivalTermSpec {
2481 age_entry: age_entry.clone(),
2482 age_exit: age_exit.clone(),
2483 event_target: event.mapv(|v| if v >= 0.5 { 1 } else { 0 }),
2484 weights: weights.clone(),
2485 derivative_guard: exact_derivative_guard,
2486 time_block,
2487 unloaded_mass_entry: prepared.unloaded_mass_entry,
2488 unloaded_mass_exit: prepared.unloaded_mass_exit,
2489 unloaded_hazard_exit: prepared.unloaded_hazard_exit,
2490 meanspec: termspec.clone(),
2491 mean_offset: threshold_offset.clone(),
2492 },
2493 frailty: config.frailty.clone().unwrap_or(FrailtySpec::None),
2494 options: BlockwiseFitOptions::default(),
2495 })
2496 };
2497
2498 let build_latent_binary_request =
2499 |candidate: &crate::families::survival_construction::SurvivalBaselineConfig| {
2500 let loading = latent_loading.ok_or_else(|| {
2501 "internal error: latent binary loading missing after frailty validation".to_string()
2502 })?;
2503 let prepared = prepare_workflow_survival_time_stack(
2504 &age_entry,
2505 &age_exit,
2506 candidate,
2507 survival_mode,
2508 None,
2509 time_anchor,
2510 exact_derivative_guard,
2511 &time_build,
2512 None,
2513 Some(loading),
2514 )?;
2515 let time_p = prepared.time_design_exit.ncols();
2516 let time_initial_log_lambdas = if prepared.time_penalties.is_empty() {
2517 None
2518 } else {
2519 Some(Array1::from_elem(
2520 prepared.time_penalties.len(),
2521 config.time_smooth_lambda.ln(),
2522 ))
2523 };
2524 let time_block = TimeBlockInput {
2525 design_entry: prepared.time_design_entry.clone(),
2526 design_exit: prepared.time_design_exit.clone(),
2527 design_derivative_exit: prepared.time_design_derivative.clone(),
2528 offset_entry: prepared.eta_offset_entry.clone(),
2529 offset_exit: prepared.eta_offset_exit.clone(),
2530 derivative_offset_exit: prepared.derivative_offset_exit.clone(),
2531 structural_monotonicity: true,
2532 penalties: prepared.time_penalties.clone(),
2533 nullspace_dims: prepared.time_nullspace_dims.clone(),
2534 initial_log_lambdas: time_initial_log_lambdas,
2535 initial_beta: Some(Array1::from_elem(time_p, 1e-4)),
2536 };
2537 Ok::<_, String>(LatentBinaryFitRequest {
2538 data: data.values.view(),
2539 spec: LatentBinaryTermSpec {
2540 age_entry: age_entry.clone(),
2541 age_exit: age_exit.clone(),
2542 event_target: event.mapv(|v| if v >= 0.5 { 1 } else { 0 }),
2543 weights: weights.clone(),
2544 derivative_guard: exact_derivative_guard,
2545 time_block,
2546 unloaded_mass_entry: prepared.unloaded_mass_entry,
2547 unloaded_mass_exit: prepared.unloaded_mass_exit,
2548 meanspec: termspec.clone(),
2549 mean_offset: threshold_offset.clone(),
2550 },
2551 frailty: config.frailty.clone().unwrap_or(FrailtySpec::None),
2552 options: BlockwiseFitOptions::default(),
2553 })
2554 };
2555
2556 let baseline_cfg = if matches!(
2557 survival_mode,
2558 SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull
2559 ) {
2560 baseline_cfg
2561 } else if baseline_cfg.target != SurvivalBaselineTarget::Linear
2562 && survival_mode == SurvivalLikelihoodMode::MarginalSlope
2563 {
2564 optimize_survival_baseline_config_with_gradient(
2565 &baseline_cfg,
2566 "workflow survival marginal-slope baseline",
2567 |candidate| {
2568 let fit =
2569 fit_survival_marginal_slope_model(build_marginal_slope_request(candidate)?)
2570 .map_err(|e| format!("survival marginal-slope fit failed: {e}"))?;
2571 let gradient = marginal_slope_baseline_chain_rule_gradient(
2572 age_entry.view(),
2573 age_exit.view(),
2574 candidate,
2575 &fit.baseline_offset_residuals,
2576 )?
2577 .ok_or_else(|| {
2578 "workflow survival marginal-slope baseline unexpectedly has no theta gradient"
2579 .to_string()
2580 })?;
2581 let hessian = marginal_slope_baseline_chain_rule_hessian(
2582 age_entry.view(),
2583 age_exit.view(),
2584 candidate,
2585 &fit.baseline_offset_residuals,
2586 &fit.baseline_offset_curvatures,
2587 )?
2588 .ok_or_else(|| {
2589 "workflow survival marginal-slope baseline unexpectedly has no theta Hessian"
2590 .to_string()
2591 })?;
2592 Ok((fit.fit.reml_score, gradient, hessian))
2593 },
2594 )?
2595 } else if baseline_cfg.target != SurvivalBaselineTarget::Linear {
2596 optimize_survival_baseline_config(
2597 &baseline_cfg,
2598 "workflow survival baseline",
2599 |candidate| match survival_mode {
2600 SurvivalLikelihoodMode::LocationScale => Ok(fit_survival_location_scale_model(
2601 build_location_scale_request(candidate)?,
2602 )
2603 .map_err(|e| format!("survival location-scale fit failed: {e}"))?
2604 .fit
2605 .fit
2606 .reml_score),
2607 SurvivalLikelihoodMode::MarginalSlope => unreachable!(
2608 "marginal-slope baseline profiling uses analytic GM-probit gradient"
2609 ),
2610 SurvivalLikelihoodMode::Latent => Ok(fit_latent_survival_model(
2611 build_latent_survival_request(candidate)?,
2612 )
2613 .map_err(|e| format!("latent survival fit failed: {e}"))?
2614 .fit
2615 .reml_score),
2616 SurvivalLikelihoodMode::LatentBinary => Ok(fit_latent_binary_model(
2617 build_latent_binary_request(candidate)?,
2618 )
2619 .map_err(|e| format!("latent binary fit failed: {e}"))?
2620 .fit
2621 .reml_score),
2622 SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull => {
2623 unreachable!()
2624 }
2625 },
2626 )?
2627 } else {
2628 baseline_cfg
2629 };
2630
2631 let request = match survival_mode {
2632 SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull => {
2633 if config.noise_offset_column.is_some() {
2634 return Err(
2635 "noise_offset_column is supported only for survival location-scale or marginal-slope"
2636 .to_string(),
2637 );
2638 }
2639 let weibull_seed = if survival_mode == SurvivalLikelihoodMode::Weibull
2640 && effective_timewiggle.is_none()
2641 {
2642 let scale = config
2643 .baseline_scale
2644 .unwrap_or_else(|| positive_survival_time_seed(&age_exit));
2645 let shape = config.baseline_shape.unwrap_or(1.0);
2646 if !scale.is_finite() || scale <= 0.0 || !shape.is_finite() || shape <= 0.0 {
2647 return Err(
2648 "weibull survival fit requires finite positive baseline_scale and baseline_shape"
2649 .to_string(),
2650 );
2651 }
2652 Some((scale, shape))
2653 } else {
2654 None
2655 };
2656 FitRequest::SurvivalTransformation(SurvivalTransformationFitRequest {
2657 data: data.values.view(),
2658 spec: SurvivalTransformationTermSpec {
2659 age_entry: age_entry.clone(),
2660 age_exit: age_exit.clone(),
2661 event_target: event.mapv(|value| if value >= 0.5 { 1 } else { 0 }),
2662 weights: weights.clone(),
2663 covariate_spec: termspec.clone(),
2664 covariate_offset: threshold_offset.clone(),
2665 baseline_cfg,
2666 likelihood_mode: survival_mode,
2667 time_anchor,
2668 time_build: time_build.clone(),
2669 timewiggle: effective_timewiggle.clone(),
2670 weibull_seed,
2671 ridge_lambda: config.ridge_lambda,
2672 },
2673 })
2674 }
2675 SurvivalLikelihoodMode::LocationScale => {
2676 FitRequest::SurvivalLocationScale(build_location_scale_request(&baseline_cfg)?)
2677 }
2678 SurvivalLikelihoodMode::MarginalSlope => {
2679 FitRequest::SurvivalMarginalSlope(build_marginal_slope_request(&baseline_cfg)?)
2680 }
2681 SurvivalLikelihoodMode::Latent => {
2682 FitRequest::LatentSurvival(build_latent_survival_request(&baseline_cfg)?)
2683 }
2684 SurvivalLikelihoodMode::LatentBinary => {
2685 FitRequest::LatentBinary(build_latent_binary_request(&baseline_cfg)?)
2686 }
2687 };
2688
2689 Ok(MaterializedModel {
2690 request,
2691 inference_notes,
2692 })
2693}
2694
2695fn materialize_transformation_normal<'a>(
2696 parsed: &ParsedFormula,
2697 data: &'a Dataset,
2698 col_map: &HashMap<String, usize>,
2699 config: &FitConfig,
2700) -> Result<MaterializedModel<'a>, String> {
2701 if parsed.linkspec.is_some() {
2702 return Err("link(...) is not supported for the transformation-normal family".to_string());
2703 }
2704 if parsed.linkwiggle.is_some() {
2705 return Err(
2706 "linkwiggle(...) is not supported for the transformation-normal family".to_string(),
2707 );
2708 }
2709 if config.noise_offset_column.is_some() {
2710 return Err(
2711 "noise_offset_column is not supported for transformation-normal models".to_string(),
2712 );
2713 }
2714 if config.frailty.is_some() {
2715 return Err("frailty is not supported for transformation-normal models".to_string());
2716 }
2717
2718 let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
2719 let y = data.values.column(y_col).to_owned();
2720 let mut inference_notes = Vec::new();
2721
2722 let policy = resolved_resource_policy(config);
2723 let mut covariate_spec =
2724 build_termspec(&parsed.terms, data, col_map, &mut inference_notes, &policy)?;
2725 if config.scale_dimensions {
2726 enable_scale_dimensions(&mut covariate_spec);
2727 }
2728
2729 let weights = resolve_weight_column(data, col_map, config.weight_column.as_deref())?;
2730 let offset = resolve_offset_column(data, col_map, config.offset_column.as_deref())?;
2731
2732 Ok(MaterializedModel {
2733 request: FitRequest::TransformationNormal(TransformationNormalFitRequest {
2734 data: data.values.view(),
2735 response: y,
2736 weights,
2737 offset,
2738 covariate_spec,
2739 config: TransformationNormalConfig::default(),
2740 options: BlockwiseFitOptions::default(),
2741 kappa_options: SpatialLengthScaleOptimizationOptions::default(),
2742 warm_start: None,
2743 }),
2744 inference_notes,
2745 })
2746}
2747
2748fn materialize_location_scale<'a>(
2749 parsed: &ParsedFormula,
2750 data: &'a Dataset,
2751 col_map: &HashMap<String, usize>,
2752 config: &FitConfig,
2753) -> Result<MaterializedModel<'a>, String> {
2754 let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
2755 let y = data.values.column(y_col).to_owned();
2756 let mut inference_notes = Vec::new();
2757
2758 let noise_formula = config
2759 .noise_formula
2760 .as_deref()
2761 .ok_or_else(|| "noise_formula is required for location-scale models".to_string())?;
2762 let noise_parsed = parse_formula(&format!("{} ~ {noise_formula}", parsed.response))?;
2763
2764 let link_choice = parse_link_choice(config.link.as_deref(), config.flexible_link)?;
2765 let family = resolve_family(config.family.as_deref(), link_choice.as_ref(), y.view())?;
2766
2767 let effective_linkwiggle =
2768 effectivelinkwiggle_formulaspec(parsed.linkwiggle.as_ref(), link_choice.as_ref());
2769
2770 let policy = resolved_resource_policy(config);
2771 let mut meanspec = build_termspec(&parsed.terms, data, col_map, &mut inference_notes, &policy)?;
2772 let mut log_sigmaspec = build_termspec(
2773 &noise_parsed.terms,
2774 data,
2775 col_map,
2776 &mut inference_notes,
2777 &policy,
2778 )?;
2779 if config.scale_dimensions {
2780 enable_scale_dimensions(&mut meanspec);
2781 enable_scale_dimensions(&mut log_sigmaspec);
2782 }
2783
2784 let weights = resolve_weight_column(data, col_map, config.weight_column.as_deref())?;
2785 let mean_offset = resolve_offset_column(data, col_map, config.offset_column.as_deref())?;
2786 let noise_offset = resolve_offset_column(data, col_map, config.noise_offset_column.as_deref())?;
2787 let kappa_options = SpatialLengthScaleOptimizationOptions::default();
2788 let options = BlockwiseFitOptions::default();
2789
2790 let wiggle_cfg = effective_linkwiggle.map(|cfg| LinkWiggleConfig {
2791 degree: cfg.degree,
2792 num_internal_knots: cfg.num_internal_knots,
2793 penalty_orders: cfg.penalty_orders,
2794 double_penalty: cfg.double_penalty,
2795 });
2796
2797 if matches!(family, LikelihoodFamily::BinomialLatentCLogLog) {
2798 return Err(
2799 "latent-cloglog-binomial is not implemented for location-scale fitting".to_string(),
2800 );
2801 }
2802
2803 if is_binomial_family(family) {
2804 let link_kind = link_choice
2805 .as_ref()
2806 .map(|c| InverseLink::Standard(c.link))
2807 .unwrap_or(InverseLink::Standard(LinkFunction::Logit));
2808 Ok(MaterializedModel {
2809 request: FitRequest::BinomialLocationScale(BinomialLocationScaleFitRequest {
2810 data: data.values.view(),
2811 spec: BinomialLocationScaleTermSpec {
2812 y,
2813 weights,
2814 link_kind,
2815 thresholdspec: meanspec,
2816 log_sigmaspec,
2817 threshold_offset: mean_offset,
2818 log_sigma_offset: noise_offset,
2819 },
2820 wiggle: wiggle_cfg,
2821 options,
2822 kappa_options,
2823 }),
2824 inference_notes,
2825 })
2826 } else {
2827 Ok(MaterializedModel {
2828 request: FitRequest::GaussianLocationScale(GaussianLocationScaleFitRequest {
2829 data: data.values.view(),
2830 spec: GaussianLocationScaleTermSpec {
2831 y,
2832 weights,
2833 meanspec,
2834 log_sigmaspec,
2835 mean_offset,
2836 log_sigma_offset: noise_offset,
2837 },
2838 wiggle: wiggle_cfg,
2839 options,
2840 kappa_options,
2841 }),
2842 inference_notes,
2843 })
2844 }
2845}
2846
#[cfg(test)]
mod tests {
    use super::*;
    use crate::basis::{DuchonNullspaceOrder, minimum_duchon_power_for_operator_penalties};
    use crate::inference::data::load_dataset_projected;
    use crate::inference::formula_dsl::{
        default_linkwiggle_formulaspec, parse_linkwiggle_formulaspec,
    };
    use crate::inference::model::{ColumnKindTag, DataSchema, SchemaColumn};
    use crate::solver::outer_strategy::{HessianSource, OuterPlan, OuterResult, Solver};
    use ndarray::Array2;
    use std::fs;
    use tempfile::tempdir;

    /// Asserts that `text` contains every fragment in `needles`.
    ///
    /// Unlike a bare `assert!(text.contains(..))`, the panic message names the
    /// missing fragment and prints the full text, so a reworded error message is
    /// diagnosable directly from the test output.
    fn assert_contains_all(text: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(
                text.contains(needle),
                "expected {needle:?} in error text, got: {text}"
            );
        }
    }

    /// Base config for survival marginal-slope tests: the `marginal-slope`
    /// survival likelihood with column `z` reserved as the z column. Every other
    /// field comes from `FitConfig::default()`, so callers override only what the
    /// individual test cares about (typically `logslope_formula`).
    fn marginal_slope_config() -> FitConfig {
        FitConfig {
            survival_likelihood: "marginal-slope".to_string(),
            z_column: Some("z".to_string()),
            ..FitConfig::default()
        }
    }

    /// Writes a two-row survival CSV with columns `entry, exit, event, x, z` into a
    /// temporary directory and loads it via `load_dataset_projected`.
    ///
    /// The temporary directory is removed when this function returns. The returned
    /// dataset is used afterwards, which presumes the loader reads the file eagerly.
    fn load_survival_dataset() -> crate::inference::data::EncodedDataset {
        let td = tempdir().expect("tempdir");
        let data_path = td.path().join("survival.csv");
        fs::write(
            &data_path,
            "entry,exit,event,x,z\n0.0,1.0,1,0.2,-0.4\n0.3,1.6,0,-0.1,0.6\n",
        )
        .expect("write survival csv");
        load_dataset_projected(
            &data_path,
            &[
                "entry".to_string(),
                "exit".to_string(),
                "event".to_string(),
                "x".to_string(),
                "z".to_string(),
            ],
        )
        .expect("load survival dataset")
    }

    /// The reserved z column must not also appear as a covariate in the main formula.
    #[test]
    fn survival_marginal_slope_materialize_rejects_z_column_in_main_formula() {
        let data = load_survival_dataset();
        let config = FitConfig {
            logslope_formula: Some("1".to_string()),
            ..marginal_slope_config()
        };

        let err = materialize("Surv(entry, exit, event) ~ x + z", &data, &config)
            .err()
            .expect("main formula should reject z-column reuse");

        assert_contains_all(
            &err,
            &["survival marginal-slope reserves z column 'z'", "main formula"],
        );
    }

    /// The reserved z column must not appear in an explicit `logslope_formula`.
    #[test]
    fn survival_marginal_slope_materialize_rejects_z_column_in_logslope_formula() {
        let data = load_survival_dataset();
        let config = FitConfig {
            logslope_formula: Some("1 + z".to_string()),
            ..marginal_slope_config()
        };

        let err = materialize("Surv(entry, exit, event) ~ x", &data, &config)
            .err()
            .expect("logslope formula should reject z-column reuse");

        assert_contains_all(
            &err,
            &[
                "survival marginal-slope reserves z column 'z'",
                "logslope_formula",
            ],
        );
    }

    /// With no `logslope_formula` configured, reusing `z` in the main formula is
    /// still rejected, and the error names the main formula.
    #[test]
    fn survival_marginal_slope_materialize_rejects_z_column_when_logslope_defaults_to_main_spec() {
        let data = load_survival_dataset();
        let config = marginal_slope_config();

        let err = materialize("Surv(entry, exit, event) ~ x + z", &data, &config)
            .err()
            .expect("defaulted logslope spec should still reject z-column reuse");

        assert_contains_all(
            &err,
            &["survival marginal-slope reserves z column 'z'", "main formula"],
        );
    }

    /// In-memory fixture with four rows and the columns
    /// `age_entry, age_exit, event, bmi, z`. `event` is tagged `Binary` and the
    /// rest `Continuous`.
    fn workflow_test_dataset() -> Dataset {
        Dataset {
            headers: vec![
                "age_entry".to_string(),
                "age_exit".to_string(),
                "event".to_string(),
                "bmi".to_string(),
                "z".to_string(),
            ],
            // Row-major: each group of five values is one row in header order.
            values: Array2::from_shape_vec(
                (4, 5),
                vec![
                    40.0, 43.0, 1.0, 22.0, -1.0, 41.0, 46.0, 0.0, 24.0, -0.2, 42.0, 47.0, 1.0,
                    27.0, 0.3, 44.0, 49.0, 0.0, 29.0, 1.2,
                ],
            )
            .expect("workflow test data shape"),
            schema: DataSchema {
                columns: vec![
                    SchemaColumn {
                        name: "age_entry".to_string(),
                        kind: ColumnKindTag::Continuous,
                        levels: vec![],
                    },
                    SchemaColumn {
                        name: "age_exit".to_string(),
                        kind: ColumnKindTag::Continuous,
                        levels: vec![],
                    },
                    SchemaColumn {
                        name: "event".to_string(),
                        kind: ColumnKindTag::Binary,
                        levels: vec![],
                    },
                    SchemaColumn {
                        name: "bmi".to_string(),
                        kind: ColumnKindTag::Continuous,
                        levels: vec![],
                    },
                    SchemaColumn {
                        name: "z".to_string(),
                        kind: ColumnKindTag::Continuous,
                        levels: vec![],
                    },
                ],
            },
            column_kinds: vec![
                ColumnKindTag::Continuous,
                ColumnKindTag::Continuous,
                ColumnKindTag::Binary,
                ColumnKindTag::Continuous,
                ColumnKindTag::Continuous,
            ],
        }
    }

    /// Builds a synthetic BFGS `OuterResult` with the given convergence flag and
    /// final `rho`. The other fields are fixed placeholders.
    fn workflow_test_outer_result(converged: bool, rho: Array1<f64>) -> OuterResult {
        OuterResult {
            rho,
            final_value: 1.25,
            iterations: 7,
            final_grad_norm: 0.5,
            final_gradient: None,
            final_hessian: None,
            converged,
            plan_used: OuterPlan {
                solver: Solver::Bfgs,
                hessian_source: HessianSource::BfgsApprox,
            },
            operator_trust_radius: None,
            operator_stop_reason: None,
        }
    }

    /// Routing rules for `linkwiggle` terms:
    /// - a `linkwiggle` in the main formula becomes the link-deviation block;
    /// - a `linkwiggle` in `logslope_formula` becomes the score-warp block.
    ///
    /// Both routings must be reported in the inference notes.
    #[test]
    fn workflow_survival_marginal_slope_routes_logslope_linkwiggle_into_score_warp_only() {
        let data = workflow_test_dataset();
        let config = FitConfig {
            logslope_formula: Some(
                "1 + linkwiggle(degree=5, internal_knots=7, penalty_order=\"2,3\")".to_string(),
            ),
            ..marginal_slope_config()
        };
        let materialized = materialize(
            "Surv(age_entry, age_exit, event) ~ s(bmi) + linkwiggle(degree=4, internal_knots=9, penalty_order=\"1\")",
            &data,
            &config,
        )
        .expect("workflow materialization should succeed");

        let MaterializedModel {
            request,
            inference_notes,
        } = materialized;
        let FitRequest::SurvivalMarginalSlope(request) = request else {
            panic!("expected survival marginal-slope request");
        };

        let link_dev = request.spec.link_dev.expect("main-formula link-dev");
        let score_warp = request.spec.score_warp.expect("logslope score-warp");
        assert_eq!(link_dev.degree, 4);
        assert_eq!(link_dev.num_internal_knots, 9);
        assert_eq!(link_dev.penalty_order, 1);
        assert_eq!(link_dev.penalty_orders, vec![1]);
        assert_eq!(score_warp.degree, 5);
        assert_eq!(score_warp.num_internal_knots, 7);
        // With several penalty orders, the scalar `penalty_order` is expected to be
        // the highest one listed.
        assert_eq!(score_warp.penalty_order, 3);
        assert_eq!(score_warp.penalty_orders, vec![2, 3]);
        assert!(
            inference_notes
                .iter()
                .any(|note| note.contains("link-deviation block")),
            "workflow notes should mention main-formula linkwiggle routing: {inference_notes:?}"
        );
        assert!(
            inference_notes
                .iter()
                .any(|note| note.contains("score-warp block")),
            "workflow notes should mention logslope_formula linkwiggle routing: {inference_notes:?}"
        );
    }

    /// A binary response together with `logslope_formula` and `z_column` selects the
    /// Bernoulli marginal-slope request.
    #[test]
    fn materialize_routes_bernoulli_marginal_slope_when_logslope_and_z_are_set() {
        let data = workflow_test_dataset();
        let config = FitConfig {
            logslope_formula: Some("1".to_string()),
            z_column: Some("z".to_string()),
            ..FitConfig::default()
        };
        let materialized = materialize("event ~ bmi", &data, &config)
            .expect("Bernoulli marginal-slope materialization should succeed");
        assert!(matches!(
            materialized.request,
            FitRequest::BernoulliMarginalSlope(_)
        ));
    }

    /// The formula-DSL default, a parsed bare `linkwiggle()`, and the runtime
    /// `DeviationBlockConfig::default()` must all agree on degree, knot count,
    /// penalty orders, and double-penalty setting.
    #[test]
    fn linkwiggle_defaults_are_consistent_across_formula_and_runtime() {
        let parsed = parse_linkwiggle_formulaspec(&Default::default(), "linkwiggle()")
            .expect("default linkwiggle should parse");
        let formula_default = default_linkwiggle_formulaspec();
        let runtime_default = DeviationBlockConfig::default();
        assert_eq!(parsed.degree, formula_default.degree);
        assert_eq!(
            parsed.num_internal_knots,
            formula_default.num_internal_knots
        );
        assert_eq!(parsed.penalty_orders, formula_default.penalty_orders);
        assert_eq!(parsed.double_penalty, formula_default.double_penalty);
        assert_eq!(runtime_default.degree, formula_default.degree);
        assert_eq!(
            runtime_default.num_internal_knots,
            formula_default.num_internal_knots
        );
        assert_eq!(
            runtime_default.penalty_orders,
            formula_default.penalty_orders
        );
        assert_eq!(
            runtime_default.double_penalty,
            formula_default.double_penalty
        );
    }

    /// Survival marginal-slope accepts an explicit probit link and rejects any
    /// other link.
    #[test]
    fn survival_marginal_slope_accepts_explicit_probit_link() {
        let data = workflow_test_dataset();
        let config = FitConfig {
            logslope_formula: Some("1".to_string()),
            ..marginal_slope_config()
        };
        // Surface the rejection reason instead of a bare `is_ok()` failure.
        if let Err(err) = materialize(
            "Surv(age_entry, age_exit, event) ~ bmi + link(type=probit)",
            &data,
            &config,
        ) {
            panic!("explicit probit should be accepted, got: {err}");
        }

        let err = materialize(
            "Surv(age_entry, age_exit, event) ~ bmi + link(type=logit)",
            &data,
            &config,
        )
        .err()
        .expect("non-probit link should be rejected");
        assert_contains_all(&err, &["only link(type=probit)"]);
    }

    /// In 16 dimensions, the minimum Duchon power chosen for operator penalties
    /// must satisfy the admissibility inequality `2 * (1 + power) > dim + 2`.
    #[test]
    fn high_dimensional_duchon_default_power_is_admissible() {
        let dim = 16;
        let power = minimum_duchon_power_for_operator_penalties(dim, DuchonNullspaceOrder::Zero, 2);
        assert!(
            2 * (1 + power) > dim + 2,
            "duchon power {power} is not admissible for dim={dim}"
        );
    }

    /// The survival link wiggle only supports some inverse links. Forcing a SAS
    /// inverse link (with link optimization disabled) must produce a descriptive
    /// error.
    #[test]
    fn survival_location_scale_wiggle_rejects_unsupported_inverse_link() {
        let data = workflow_test_dataset();
        let materialized = materialize(
            "Surv(age_entry, age_exit, event) ~ bmi + linkwiggle(degree=4, internal_knots=3, penalty_order=\"1\")",
            &data,
            &FitConfig::default(),
        )
        .expect("workflow materialization should succeed");

        let MaterializedModel { request, .. } = materialized;
        let FitRequest::SurvivalLocationScale(mut request) = request else {
            panic!("expected survival location-scale request");
        };
        request.spec.inverse_link = InverseLink::Sas(
            state_from_sasspec(SasLinkSpec {
                initial_epsilon: 0.1,
                initial_log_delta: 0.0,
            })
            .expect("valid SAS state"),
        );
        request.optimize_inverse_link = false;

        let err = fit_survival_location_scale_model(request)
            .err()
            .expect("survival link wiggle should reject unsupported inverse links");

        assert_contains_all(&err, &["survival link wiggle", "does not support"]);
    }

    /// A non-converged outer inverse-link search is reported as an error that
    /// mentions non-convergence and the final objective value.
    #[test]
    fn survival_inverse_link_result_requires_convergence() {
        let err = recover_converged_survival_inverse_link(
            workflow_test_outer_result(false, Array1::from_vec(vec![0.1, -0.2])),
            "survival inverse-link optimization (SAS, dim=2)",
            |_| Some(InverseLink::Standard(LinkFunction::Logit)),
        )
        .expect_err("non-converged inverse-link search should fail");

        assert_contains_all(&err, &["did not converge", "final_objective"]);
    }

    /// A converged search whose `rho` cannot be mapped back to an inverse-link
    /// state is reported as an error that includes the offending parameters.
    #[test]
    fn survival_inverse_link_result_requires_recoverable_state() {
        let err = recover_converged_survival_inverse_link(
            workflow_test_outer_result(true, Array1::from_vec(vec![9.0, 8.0])),
            "survival inverse-link optimization (mixture, dim=2)",
            |_| None,
        )
        .expect_err("unrecoverable inverse-link state should fail");

        assert_contains_all(&err, &["produced an invalid inverse-link state", "9.0"]);
    }
}