1use super::*;
2
/// Strategy used to keep the time block monotone.
///
/// The two predicates [`TimeBlockMonotonicity::is_coordinate_cone`] and
/// [`TimeBlockMonotonicity::requires_row_constraints`] partition the variants: every
/// variant answers `true` to exactly one of them.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TimeBlockMonotonicity {
    /// Classified as a coordinate-cone strategy by `is_coordinate_cone`.
    EnforcedByCoordinateCone,
    /// The only variant for which `requires_row_constraints` is `true`.
    EnforcedByRowConstraint,
    /// Structural I-spline variant; also classified as a coordinate-cone strategy.
    StructuralISpline,
}

impl TimeBlockMonotonicity {
    /// Returns `true` for `EnforcedByCoordinateCone` and `StructuralISpline`.
    ///
    /// The match is exhaustive on purpose: adding a variant forces a decision here
    /// instead of silently falling into a catch-all.
    #[inline]
    pub fn is_coordinate_cone(self) -> bool {
        match self {
            Self::EnforcedByCoordinateCone | Self::StructuralISpline => true,
            Self::EnforcedByRowConstraint => false,
        }
    }

    /// Returns `true` only for `EnforcedByRowConstraint`.
    #[inline]
    pub fn requires_row_constraints(self) -> bool {
        match self {
            Self::EnforcedByRowConstraint => true,
            Self::EnforcedByCoordinateCone | Self::StructuralISpline => false,
        }
    }
}
58
/// Inputs for the time block of a survival location-scale model.
///
/// The block carries three designs (entry, exit, and a derivative design at exit), each
/// with its own offset vector. It also holds the monotonicity strategy, dense penalty
/// matrices, and optional warm starts.
/// NOTE(review): rows are presumably aligned with the spec's `age_entry` / `age_exit`
/// observations — confirm at the construction site.
#[derive(Clone)]
pub struct TimeBlockInput {
    /// Design matrix for the entry side.
    pub design_entry: DesignMatrix,
    /// Design matrix for the exit side.
    pub design_exit: DesignMatrix,
    /// Derivative design matrix for the exit side.
    pub design_derivative_exit: DesignMatrix,
    /// Offset paired with `design_entry`.
    pub offset_entry: Array1<f64>,
    /// Offset paired with `design_exit`.
    pub offset_exit: Array1<f64>,
    /// Offset paired with `design_derivative_exit`.
    pub derivative_offset_exit: Array1<f64>,
    /// How monotonicity of this block is enforced.
    pub time_monotonicity: TimeBlockMonotonicity,
    /// Dense penalty matrices for this block.
    pub penalties: Vec<Array2<f64>>,
    /// Null-space dimensions associated with the penalties (presumably one per penalty —
    /// confirm).
    pub nullspace_dims: Vec<usize>,
    /// Optional starting log smoothing parameters.
    pub initial_log_lambdas: Option<Array1<f64>>,
    /// Optional starting coefficients.
    pub initial_beta: Option<Array1<f64>>,
}
77
/// Inputs for a covariate block whose effect varies with time.
///
/// The block pairs a covariate design with dense time bases at entry, at exit, and a
/// derivative basis at exit.
/// NOTE(review): how the covariate design and the time bases are combined (e.g. a
/// row-wise product) is decided where this input is consumed — confirm there.
#[derive(Clone)]
pub struct TimeDependentCovariateBlockInput {
    /// Covariate design matrix.
    pub design_covariates: DesignMatrix,
    /// Time basis for the entry side.
    pub time_basis_entry: Array2<f64>,
    /// Time basis for the exit side.
    pub time_basis_exit: Array2<f64>,
    /// Derivative time basis for the exit side.
    pub time_basis_derivative_exit: Array2<f64>,
    /// Penalties for this block.
    pub penalties: Vec<PenaltyMatrix>,
    /// Optional starting log smoothing parameters.
    pub initial_log_lambdas: Option<Array1<f64>>,
    /// Optional starting coefficients.
    pub initial_beta: Option<Array1<f64>>,
    /// Offset vector for this block.
    pub offset: Array1<f64>,
}
105
/// A covariate block that is either time-invariant or time-varying.
///
/// Used for the `threshold_block` and `log_sigma_block` of `SurvivalLocationScaleSpec`.
#[derive(Clone)]
pub enum CovariateBlockKind {
    /// Time-invariant block given as an ordinary parameter block.
    Static(ParameterBlockInput),
    /// Time-varying block built from a covariate design and time bases.
    TimeVarying(TimeDependentCovariateBlockInput),
}
113
/// Inputs for the optional link-wiggle block.
///
/// The block is described by a design, a knot vector, and a spline degree, together with
/// its penalties and optional warm starts.
#[derive(Clone)]
pub struct LinkWiggleBlockInput {
    /// Design matrix for the wiggle block.
    pub design: DesignMatrix,
    /// Knot vector of the wiggle basis.
    pub knots: Array1<f64>,
    /// Degree of the wiggle basis.
    pub degree: usize,
    /// Penalty specifications for this block.
    pub penalties: Vec<gam_terms::penalty_spec::PenaltySpec>,
    /// Null-space dimensions associated with the penalties (presumably one per penalty —
    /// confirm).
    pub nullspace_dims: Vec<usize>,
    /// Optional starting log smoothing parameters.
    pub initial_log_lambdas: Option<Array1<f64>>,
    /// Optional starting coefficients.
    pub initial_beta: Option<Array1<f64>>,
}
125
/// Basis description for the optional time-wiggle block.
///
/// Only the knots, degree and column count are stored; no design matrix is held here.
/// NOTE(review): the design is presumably rebuilt from these values where needed — confirm.
#[derive(Clone)]
pub struct TimeWiggleBlockInput {
    /// Knot vector of the time-wiggle basis.
    pub knots: Array1<f64>,
    /// Degree of the time-wiggle basis.
    pub degree: usize,
    /// Number of basis columns.
    pub ncols: usize,
}
132
/// Full specification of a survival location-scale fit with already-built blocks.
///
/// Unlike `SurvivalLocationScaleTermSpec`, the threshold and log-sigma blocks here are
/// concrete `CovariateBlockKind` inputs rather than term-collection specifications.
#[derive(Clone)]
pub(crate) struct SurvivalLocationScaleSpec {
    /// Entry ages, one per observation (presumably — confirm alignment).
    pub age_entry: Array1<f64>,
    /// Exit ages, one per observation (presumably — confirm alignment).
    pub age_exit: Array1<f64>,
    /// Event indicator / target values.
    pub event_target: Array1<f64>,
    /// Observation weights.
    pub weights: Array1<f64>,
    /// Inverse link used by the model.
    pub inverse_link: InverseLink,
    /// Guard value for time derivatives.
    /// NOTE(review): presumably a lower bound on the derivative; see
    /// `DEFAULT_SURVIVAL_LOCATION_SCALE_DERIVATIVE_GUARD` — confirm its use.
    pub derivative_guard: f64,
    /// Iteration limit for the fit.
    pub max_iter: usize,
    /// Convergence tolerance for the fit.
    pub tol: f64,
    /// Time block inputs.
    pub time_block: TimeBlockInput,
    /// Threshold covariate block.
    pub threshold_block: CovariateBlockKind,
    /// Log-sigma covariate block.
    pub log_sigma_block: CovariateBlockKind,
    /// Optional time-wiggle block.
    pub timewiggle_block: Option<TimeWiggleBlockInput>,
    /// Optional link-wiggle block.
    pub linkwiggle_block: Option<LinkWiggleBlockInput>,
    /// Optional shared warm-start session.
    pub cache_session: Option<std::sync::Arc<gam_runtime::warm_start::Session>>,
    /// Additional shared warm-start sessions.
    /// NOTE(review): how "mirror" sessions differ from `cache_session` is not visible here.
    pub cache_mirror_sessions: Vec<std::sync::Arc<gam_runtime::warm_start::Session>>,
}
155
/// Template describing how a term-specified covariate block is realised.
///
/// Used as `threshold_template` / `log_sigma_template` in `SurvivalLocationScaleTermSpec`.
#[derive(Clone)]
pub enum SurvivalCovariateTermBlockTemplate {
    /// The block does not vary with time.
    Static,
    /// The block varies with time through the supplied dense time bases and penalties.
    TimeVarying {
        /// Time basis for the entry side.
        time_basis_entry: Array2<f64>,
        /// Time basis for the exit side.
        time_basis_exit: Array2<f64>,
        /// Derivative time basis for the exit side.
        time_basis_derivative_exit: Array2<f64>,
        /// Dense penalty matrices acting on the time dimension.
        time_penalties: Vec<Array2<f64>>,
    },
}
166
/// Term-level specification of a survival location-scale fit.
///
/// The threshold and log-sigma blocks are given as `TermCollectionSpec`s plus offsets and
/// a `SurvivalCovariateTermBlockTemplate`, rather than as pre-built blocks.
#[derive(Clone)]
pub struct SurvivalLocationScaleTermSpec {
    /// Entry ages, one per observation (presumably — confirm alignment).
    pub age_entry: Array1<f64>,
    /// Exit ages, one per observation (presumably — confirm alignment).
    pub age_exit: Array1<f64>,
    /// Event indicator / target values.
    pub event_target: Array1<f64>,
    /// Observation weights.
    pub weights: Array1<f64>,
    /// Inverse link used by the model.
    pub inverse_link: InverseLink,
    /// Guard value for time derivatives.
    /// NOTE(review): presumably a lower bound on the derivative; see
    /// `DEFAULT_SURVIVAL_LOCATION_SCALE_DERIVATIVE_GUARD` — confirm its use.
    pub derivative_guard: f64,
    /// Iteration limit for the fit.
    pub max_iter: usize,
    /// Convergence tolerance for the fit.
    pub tol: f64,
    /// Time block inputs.
    pub time_block: TimeBlockInput,
    /// Term specification for the threshold block.
    pub thresholdspec: TermCollectionSpec,
    /// Term specification for the log-sigma block.
    pub log_sigmaspec: TermCollectionSpec,
    /// Offset for the threshold block.
    pub threshold_offset: Array1<f64>,
    /// Offset for the log-sigma block.
    pub log_sigma_offset: Array1<f64>,
    /// Whether/how the threshold block varies with time.
    pub threshold_template: SurvivalCovariateTermBlockTemplate,
    /// Whether/how the log-sigma block varies with time.
    pub log_sigma_template: SurvivalCovariateTermBlockTemplate,
    /// Optional time-wiggle block.
    pub timewiggle_block: Option<TimeWiggleBlockInput>,
    /// Optional link-wiggle block.
    pub linkwiggle_block: Option<LinkWiggleBlockInput>,
    /// Optional starting log smoothing parameters for the threshold block.
    pub initial_threshold_log_lambdas: Option<Array1<f64>>,
    /// Optional starting log smoothing parameters for the log-sigma block.
    pub initial_log_sigma_log_lambdas: Option<Array1<f64>>,
    /// Optional shared warm-start session.
    pub cache_session: Option<std::sync::Arc<gam_runtime::warm_start::Session>>,
    /// Additional shared warm-start sessions.
    /// NOTE(review): how "mirror" sessions differ from `cache_session` is not visible here.
    pub cache_mirror_sessions: Vec<std::sync::Arc<gam_runtime::warm_start::Session>>,
}
204
/// Default value, `1e-6`, for the survival location-scale `derivative_guard` setting.
///
/// NOTE(review): presumably used as the default for the specs' `derivative_guard`
/// field — confirm at the call sites.
pub const DEFAULT_SURVIVAL_LOCATION_SCALE_DERIVATIVE_GUARD: f64 = 1.0e-6;
206
/// Result of a term-level survival location-scale fit.
///
/// Bundles the unified fit with the resolved term specifications and designs for the
/// threshold and log-sigma blocks, plus offset-channel diagnostics.
pub struct SurvivalLocationScaleTermFitResult {
    /// The unified fit result.
    pub fit: UnifiedFitResult,
    /// Threshold term specification as resolved during fitting.
    pub resolved_thresholdspec: TermCollectionSpec,
    /// Log-sigma term specification as resolved during fitting.
    pub resolved_log_sigmaspec: TermCollectionSpec,
    /// Threshold design built from the resolved specification.
    pub threshold_design: TermCollectionDesign,
    /// Log-sigma design built from the resolved specification.
    pub log_sigma_design: TermCollectionDesign,
    /// Per-channel residuals for the baseline offsets.
    /// NOTE(review): exact definition lives with `OffsetChannelResiduals` — confirm.
    pub baseline_offset_residuals: OffsetChannelResiduals,
    /// Per-channel curvatures for the baseline offsets.
    /// NOTE(review): exact definition lives with `OffsetChannelCurvatures` — confirm.
    pub baseline_offset_curvatures: OffsetChannelCurvatures,
    /// Optional data-fit gradient with respect to link parameters.
    pub link_param_data_fit_gradient: Option<Array1<f64>>,
}
230
/// Raw pieces of a survival location-scale fit, consumed by `survival_fit_from_parts`.
///
/// `survival_fit_from_parts` validates these before assembling a `UnifiedFitResult`:
/// - vectors, scalars and the optional covariance must pass its finiteness checks;
/// - `beta_link_wiggle` requires `link_wiggle_knots` and `link_wiggle_degree`, and that
///   metadata is rejected without `beta_link_wiggle`;
/// - `lambdas_linkwiggle` requires `beta_link_wiggle`;
/// - each `lambdas_*` vector may have at most as many entries as its `beta_*` vector;
/// - `covariance_conditional` and `geometry.penalized_hessian` must be `p x p`, where `p`
///   is the total number of coefficients across all blocks.
pub struct SurvivalLocationScaleFitResultParts {
    /// Time-block coefficients.
    pub beta_time: Array1<f64>,
    /// Threshold-block coefficients.
    pub beta_threshold: Array1<f64>,
    /// Log-sigma-block coefficients.
    pub beta_log_sigma: Array1<f64>,
    /// Optional link-wiggle coefficients.
    pub beta_link_wiggle: Option<Array1<f64>>,
    /// Knots of the link-wiggle basis; stored in the result artifacts.
    pub link_wiggle_knots: Option<Array1<f64>>,
    /// Degree of the link-wiggle basis; stored in the result artifacts.
    pub link_wiggle_degree: Option<usize>,
    /// Smoothing parameters of the time block.
    pub lambdas_time: Array1<f64>,
    /// Smoothing parameters of the threshold block.
    pub lambdas_threshold: Array1<f64>,
    /// Smoothing parameters of the log-sigma block.
    pub lambdas_log_sigma: Array1<f64>,
    /// Smoothing parameters of the link-wiggle block, if any.
    pub lambdas_linkwiggle: Option<Array1<f64>>,
    /// Log-likelihood; the result's deviance is computed as `-2 * log_likelihood`.
    pub log_likelihood: f64,
    /// REML score reported in the result.
    pub reml_score: f64,
    /// Penalty term reported in the result.
    pub stable_penalty_term: f64,
    /// Penalized objective reported in the result.
    pub penalized_objective: f64,
    /// Passed through to the result's `used_device`.
    pub used_device: bool,
    /// Number of outer iterations performed.
    pub outer_iterations: usize,
    /// Final outer gradient norm, if available (must be finite when present).
    pub outer_gradient_norm: Option<f64>,
    /// Whether the outer optimisation converged.
    pub outer_converged: bool,
    /// Optional conditional covariance of all coefficients (`p x p`).
    pub covariance_conditional: Option<Array2<f64>>,
    /// Optional fit geometry; its penalized Hessian must be `p x p`.
    pub geometry: Option<FitGeometry>,
}
260
261#[derive(Clone, Copy)]
262pub(crate) struct SurvivalLambdaLayout {
263 pub(crate) k_time: usize,
264 pub(crate) k_threshold: usize,
265 pub(crate) k_log_sigma: usize,
266 pub(crate) k_wiggle: usize,
267}
268
269impl SurvivalLambdaLayout {
270 pub(crate) fn new(
271 k_time: usize,
272 k_threshold: usize,
273 k_log_sigma: usize,
274 k_wiggle: usize,
275 ) -> Self {
276 Self {
277 k_time,
278 k_threshold,
279 k_log_sigma,
280 k_wiggle,
281 }
282 }
283
284 pub(crate) fn total(&self) -> usize {
285 self.k_time + self.k_threshold + self.k_log_sigma + self.k_wiggle
286 }
287
288 pub(crate) fn time_range(&self) -> std::ops::Range<usize> {
289 0..self.k_time
290 }
291
292 pub(crate) fn threshold_range(&self) -> std::ops::Range<usize> {
293 self.k_time..self.k_time + self.k_threshold
294 }
295
296 pub(crate) fn log_sigma_range(&self) -> std::ops::Range<usize> {
297 self.k_time + self.k_threshold..self.k_time + self.k_threshold + self.k_log_sigma
298 }
299
300 pub(crate) fn wiggle_range(&self) -> std::ops::Range<usize> {
301 self.k_time + self.k_threshold + self.k_log_sigma..self.total()
302 }
303
304 pub(crate) fn validate_rho(&self, rho: &Array1<f64>, label: &str) -> Result<(), String> {
305 if rho.len() != self.total() {
306 return Err(SurvivalLocationScaleError::DimensionMismatch {
307 reason: format!(
308 "{label} rho length mismatch: got {}, expected {}",
309 rho.len(),
310 self.total()
311 ),
312 }
313 .into());
314 }
315 Ok::<(), _>(())
316 }
317
318 pub(crate) fn time_from(&self, rho: &Array1<f64>) -> Array1<f64> {
319 let range = self.time_range();
320 rho.slice(s![range.start..range.end]).to_owned()
321 }
322
323 pub(crate) fn threshold_from(&self, rho: &Array1<f64>) -> Array1<f64> {
324 let range = self.threshold_range();
325 rho.slice(s![range.start..range.end]).to_owned()
326 }
327
328 pub(crate) fn log_sigma_from(&self, rho: &Array1<f64>) -> Array1<f64> {
329 let range = self.log_sigma_range();
330 rho.slice(s![range.start..range.end]).to_owned()
331 }
332
333 pub(crate) fn wiggle_from(&self, rho: &Array1<f64>) -> Option<Array1<f64>> {
334 if self.k_wiggle == 0 {
335 None
336 } else {
337 let range = self.wiggle_range();
338 Some(rho.slice(s![range.start..range.end]).to_owned())
339 }
340 }
341}
342
343pub fn survival_fit_from_parts(
345 parts: SurvivalLocationScaleFitResultParts,
346) -> Result<UnifiedFitResult, String> {
347 let SurvivalLocationScaleFitResultParts {
348 beta_time,
349 beta_threshold,
350 beta_log_sigma,
351 beta_link_wiggle,
352 link_wiggle_knots,
353 link_wiggle_degree,
354 lambdas_time,
355 lambdas_threshold,
356 lambdas_log_sigma,
357 lambdas_linkwiggle,
358 log_likelihood,
359 reml_score,
360 stable_penalty_term,
361 penalized_objective,
362 used_device,
363 outer_iterations,
364 outer_gradient_norm,
365 outer_converged,
366 covariance_conditional,
367 geometry,
368 } = parts;
369
370 validate_all_finite_estimation("survival_fit.beta_time", beta_time.iter().copied())
372 .map_err(|e| e.to_string())?;
373 validate_all_finite_estimation(
374 "survival_fit.beta_threshold",
375 beta_threshold.iter().copied(),
376 )
377 .map_err(|e| e.to_string())?;
378 validate_all_finite_estimation(
379 "survival_fit.beta_log_sigma",
380 beta_log_sigma.iter().copied(),
381 )
382 .map_err(|e| e.to_string())?;
383 if let Some(beta_wiggle) = beta_link_wiggle.as_ref() {
384 validate_all_finite_estimation(
385 "survival_fit.beta_link_wiggle",
386 beta_wiggle.iter().copied(),
387 )
388 .map_err(|e| e.to_string())?;
389 let knots = link_wiggle_knots.as_ref().ok_or_else(|| {
390 "survival_fit.beta_link_wiggle requires link_wiggle_knots".to_string()
391 })?;
392 validate_all_finite_estimation("survival_fit.link_wiggle_knots", knots.iter().copied())
393 .map_err(|e| e.to_string())?;
394 if link_wiggle_degree.is_none() {
395 return Err(SurvivalLocationScaleError::InvalidConfiguration {
396 reason: "survival_fit.beta_link_wiggle requires link_wiggle_degree".to_string(),
397 }
398 .into());
399 }
400 } else if link_wiggle_knots.is_some() || link_wiggle_degree.is_some() {
401 return Err(SurvivalLocationScaleError::InvalidConfiguration {
402 reason: "survival_fit link-wiggle metadata requires beta_link_wiggle coefficients"
403 .to_string(),
404 }
405 .into());
406 }
407 validate_all_finite_estimation("survival_fit.lambdas_time", lambdas_time.iter().copied())
408 .map_err(|e| e.to_string())?;
409 validate_all_finite_estimation(
410 "survival_fit.lambdas_threshold",
411 lambdas_threshold.iter().copied(),
412 )
413 .map_err(|e| e.to_string())?;
414 validate_all_finite_estimation(
415 "survival_fit.lambdas_log_sigma",
416 lambdas_log_sigma.iter().copied(),
417 )
418 .map_err(|e| e.to_string())?;
419 if lambdas_time.len() > beta_time.len() {
427 return Err(SurvivalLocationScaleError::DimensionMismatch {
428 reason: format!(
429 "survival_fit.lambdas_time has {} entries but beta_time has only {} \
430 coefficients; each lambda corresponds to a penalty term on this block",
431 lambdas_time.len(),
432 beta_time.len()
433 ),
434 }
435 .into());
436 }
437 if lambdas_threshold.len() > beta_threshold.len() {
438 return Err(SurvivalLocationScaleError::DimensionMismatch {
439 reason: format!(
440 "survival_fit.lambdas_threshold has {} entries but beta_threshold has only {} \
441 coefficients; each lambda corresponds to a penalty term on this block",
442 lambdas_threshold.len(),
443 beta_threshold.len()
444 ),
445 }
446 .into());
447 }
448 if lambdas_log_sigma.len() > beta_log_sigma.len() {
449 return Err(SurvivalLocationScaleError::DimensionMismatch {
450 reason: format!(
451 "survival_fit.lambdas_log_sigma has {} entries but beta_log_sigma has only {} \
452 coefficients; each lambda corresponds to a penalty term on this block",
453 lambdas_log_sigma.len(),
454 beta_log_sigma.len()
455 ),
456 }
457 .into());
458 }
459 if let Some(lambdas_wiggle) = lambdas_linkwiggle.as_ref() {
460 if beta_link_wiggle.is_none() {
461 return Err(SurvivalLocationScaleError::InvalidConfiguration {
462 reason: "survival_fit.lambdas_linkwiggle requires beta_link_wiggle".to_string(),
463 }
464 .into());
465 }
466 validate_all_finite_estimation(
467 "survival_fit.lambdas_linkwiggle",
468 lambdas_wiggle.iter().copied(),
469 )
470 .map_err(|e| e.to_string())?;
471 let wiggle_len = beta_link_wiggle.as_ref().map_or(0, |beta| beta.len());
472 if lambdas_wiggle.len() > wiggle_len {
473 return Err(SurvivalLocationScaleError::DimensionMismatch {
474 reason: format!(
475 "survival_fit.lambdas_linkwiggle has {} entries but beta_link_wiggle has \
476 only {} coefficients; each lambda corresponds to a penalty term on this block",
477 lambdas_wiggle.len(),
478 wiggle_len
479 ),
480 }
481 .into());
482 }
483 }
484 ensure_finite_scalar_estimation("survival_fit.log_likelihood", log_likelihood)
485 .map_err(|e| e.to_string())?;
486 ensure_finite_scalar_estimation("survival_fit.reml_score", reml_score)
487 .map_err(|e| e.to_string())?;
488 ensure_finite_scalar_estimation("survival_fit.stable_penalty_term", stable_penalty_term)
489 .map_err(|e| e.to_string())?;
490 ensure_finite_scalar_estimation("survival_fit.penalized_objective", penalized_objective)
491 .map_err(|e| e.to_string())?;
492 if let Some(g) = outer_gradient_norm {
493 ensure_finite_scalar_estimation("survival_fit.outer_gradient_norm", g)
494 .map_err(|e| e.to_string())?;
495 }
496
497 let total_p = beta_time.len()
498 + beta_threshold.len()
499 + beta_log_sigma.len()
500 + beta_link_wiggle.as_ref().map_or(0, |beta| beta.len());
501 if let Some(cov) = covariance_conditional.as_ref() {
502 validate_all_finite_estimation("survival_fit.covariance_conditional", cov.iter().copied())
503 .map_err(|e| e.to_string())?;
504 let (rows, cols) = cov.dim();
505 if rows != total_p || cols != total_p {
506 return Err(SurvivalLocationScaleError::InvalidConfiguration {
507 reason: format!(
508 "survival_fit.covariance_conditional must be {}x{}, got {}x{}",
509 total_p, total_p, rows, cols
510 ),
511 }
512 .into());
513 }
514 }
515 if let Some(geom) = geometry.as_ref() {
516 geom.validate_numeric_finiteness()
517 .map_err(|e| e.to_string())?;
518 let (rows, cols) = geom.penalized_hessian.dim();
519 if rows != total_p || cols != total_p {
520 return Err(SurvivalLocationScaleError::InvalidConfiguration {
521 reason: format!(
522 "survival_fit.geometry.penalized_hessian must be {}x{}, got {}x{}",
523 total_p, total_p, rows, cols
524 ),
525 }
526 .into());
527 }
528 if geom.working_weights.len() != geom.working_response.len() {
529 return Err(SurvivalLocationScaleError::DimensionMismatch {
530 reason: format!(
531 "survival_fit.geometry working length mismatch: weights={}, response={}",
532 geom.working_weights.len(),
533 geom.working_response.len()
534 ),
535 }
536 .into());
537 }
538 }
539
540 use crate::model_types::{BlockRole, FittedBlock, FittedLinkState, UnifiedFitResultParts};
542 let mut blocks = vec![
543 FittedBlock {
544 beta: beta_time.clone(),
545 role: BlockRole::Time,
546 edf: 0.0,
547 lambdas: lambdas_time.clone(),
548 },
549 FittedBlock {
550 beta: beta_threshold.clone(),
551 role: BlockRole::Threshold,
552 edf: 0.0,
553 lambdas: lambdas_threshold.clone(),
554 },
555 FittedBlock {
556 beta: beta_log_sigma.clone(),
557 role: BlockRole::Scale,
558 edf: 0.0,
559 lambdas: lambdas_log_sigma.clone(),
560 },
561 ];
562 if let Some(ref bw) = beta_link_wiggle {
563 blocks.push(FittedBlock {
564 beta: bw.clone(),
565 role: BlockRole::LinkWiggle,
566 edf: 0.0,
567 lambdas: lambdas_linkwiggle
568 .clone()
569 .unwrap_or_else(|| Array1::zeros(0)),
570 });
571 }
572 let all_lambdas: Vec<f64> = blocks
573 .iter()
574 .flat_map(|b| b.lambdas.iter().copied())
575 .collect();
576 let log_lambdas = Array1::from_vec(
577 all_lambdas
578 .iter()
579 .map(|&v| if v > 0.0 { v.ln() } else { f64::NEG_INFINITY })
580 .collect(),
581 );
582 let deviance = -2.0 * log_likelihood;
583 crate::model_types::UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
584 blocks,
585 log_lambdas,
586 lambdas: Array1::from_vec(all_lambdas),
587 likelihood_family: None,
588 likelihood_scale: gam_problem::LikelihoodScaleMetadata::Unspecified,
589 log_likelihood_normalization: gam_problem::LogLikelihoodNormalization::UserProvided,
590 log_likelihood,
591 deviance,
592 reml_score,
593 stable_penalty_term,
594 penalized_objective,
595 used_device,
596 outer_iterations,
597 outer_converged,
598 outer_gradient_norm,
599 standard_deviation: 1.0,
600 covariance_conditional,
601 covariance_corrected: None,
602 inference: None,
603 fitted_link: FittedLinkState::Standard(None),
604 geometry,
605 block_states: Vec::new(),
606 pirls_status: gam_solve::pirls::PirlsStatus::Converged,
607 max_abs_eta: 0.0,
608 constraint_kkt: None,
609 artifacts: crate::model_types::FitArtifacts {
610 pirls: None,
611 null_space_logdet: None,
612 null_space_dim: None,
613 survival_link_wiggle_knots: link_wiggle_knots,
614 survival_link_wiggle_degree: link_wiggle_degree,
615 criterion_certificate: None,
616 rho_posterior_certificate: None,
617 rho_posterior_escalation: None,
618 rho_covariance: None,
619 },
620 inner_cycles: 0,
621 })
622 .map_err(|e| e.to_string())
623}
624
/// Inputs for predicting from a fitted survival location-scale model.
///
/// Each linear-predictor block is given as a design plus an offset. The optional time
/// wiggle and link wiggle are described by their knots and degree.
#[derive(Clone)]
pub struct SurvivalLocationScalePredictInput {
    /// Dense time design for the exit side.
    pub x_time_exit: Array2<f64>,
    /// Offset added to the time predictor on the exit side.
    pub eta_time_offset_exit: Array1<f64>,
    /// Knots of the time-wiggle basis, if any.
    pub time_wiggle_knots: Option<Array1<f64>>,
    /// Degree of the time-wiggle basis, if any.
    pub time_wiggle_degree: Option<usize>,
    /// Number of time-wiggle basis columns.
    /// NOTE(review): presumably 0 when there is no time wiggle — confirm.
    pub time_wiggle_ncols: usize,
    /// Threshold design.
    pub x_threshold: DesignMatrix,
    /// Offset added to the threshold predictor.
    pub eta_threshold_offset: Array1<f64>,
    /// Log-sigma design.
    pub x_log_sigma: DesignMatrix,
    /// Offset added to the log-sigma predictor.
    pub eta_log_sigma_offset: Array1<f64>,
    /// Optional link-wiggle design.
    pub x_link_wiggle: Option<DesignMatrix>,
    /// Knots of the link-wiggle basis, if any.
    pub link_wiggle_knots: Option<Array1<f64>>,
    /// Degree of the link-wiggle basis, if any.
    pub link_wiggle_degree: Option<usize>,
    /// Inverse link used to map predictors to probabilities.
    pub inverse_link: InverseLink,
}
641
/// Point predictions from a survival location-scale model.
#[derive(Clone, Debug)]
pub struct SurvivalLocationScalePredictResult {
    /// Linear predictor (presumably one entry per prediction row — confirm).
    pub eta: Array1<f64>,
    /// Predicted survival probabilities.
    pub survival_prob: Array1<f64>,
}
647
/// Predictions from a survival location-scale model together with standard errors.
#[derive(Clone)]
pub struct SurvivalLocationScalePredictUncertaintyResult {
    /// Linear predictor.
    pub eta: Array1<f64>,
    /// Predicted survival probabilities.
    pub survival_prob: Array1<f64>,
    /// Standard errors on the linear-predictor scale.
    pub eta_standard_error: Array1<f64>,
    /// Optional standard errors on the response (probability) scale.
    pub response_standard_error: Option<Array1<f64>>,
}
655
656pub(crate) fn initial_log_lambdas<T>(
657 penalties: &[T],
658 rho0: Option<Array1<f64>>,
659) -> Result<Array1<f64>, String> {
660 let k = penalties.len();
661 let rho = rho0.unwrap_or_else(|| Array1::zeros(k));
662 if rho.len() != k {
663 return Err(SurvivalLocationScaleError::DimensionMismatch {
664 reason: format!(
665 "initial_log_lambdas mismatch: got {}, expected {k}",
666 rho.len()
667 ),
668 }
669 .into());
670 }
671 Ok(rho)
672}