Skip to main content

gam_models/survival/location_scale/
spec.rs

1use super::*;
2
/// Strategy a time block uses to guarantee the derivative-guard monotonicity
/// constraint `q'(t) ≥ guard`.
///
/// Consuming families inspect this value to pick the constraint shape handed
/// to the inner active-set / KKT machinery, and to reject a parameterization
/// that does not fit the family. For example, `survival_marginal_slope` must
/// not be paired with a coordinate-cone-only basis (that would bring back the
/// phantom-multiplier bug its row-wise representation fixed), and
/// `survival_location_scale` must not be paired with a row-wise representation
/// (its reduced KKT system would become rank-deficient on the cone basis).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TimeBlockMonotonicity {
    /// Coefficients are held in a per-coordinate cone `β_j ≥ 0`; any offsets
    /// are the family's responsibility. Used by the location-scale / latent
    /// paths, whose bases yield a non-negative derivative whenever the cone
    /// holds.
    EnforcedByCoordinateCone,
    /// Coefficients satisfy row-wise `D β + o ≥ guard` on every observation
    /// row. Required when the basis has negative-derivative directions that
    /// no coordinate cone can express without leaving phantom KKT multipliers
    /// on a binding row. Used by `survival_marginal_slope` under the additive
    /// base.
    EnforcedByRowConstraint,
    /// A structurally monotone base such as `q'(t) = guard + I(t)·γ` with
    /// `γ ≥ 0`, so monotonicity holds pointwise from the cone. Constraint
    /// generation treats it like a coordinate cone; the stronger geometric
    /// guarantee is kept for diagnostics and possible future fast paths
    /// (e.g. skipping per-row validation).
    StructuralISpline,
}

impl TimeBlockMonotonicity {
    /// Returns `true` when a coordinate cone alone enforces monotonicity, so
    /// no row-wise `D β ≥ b` constraint matrix is needed.
    ///
    /// This holds for `EnforcedByCoordinateCone` and `StructuralISpline`.
    #[inline]
    pub fn is_coordinate_cone(self) -> bool {
        // Exhaustive on purpose: a new variant must be classified explicitly.
        match self {
            Self::EnforcedByCoordinateCone | Self::StructuralISpline => true,
            Self::EnforcedByRowConstraint => false,
        }
    }

    /// Returns `true` when row-wise `D β + o ≥ guard` constraints must be
    /// emitted so the inner active-set/KKT machinery sees the binding
    /// multipliers.
    ///
    /// Only `EnforcedByRowConstraint` requires them.
    #[inline]
    pub fn requires_row_constraints(self) -> bool {
        match self {
            Self::EnforcedByRowConstraint => true,
            Self::EnforcedByCoordinateCone | Self::StructuralISpline => false,
        }
    }
}
58
/// Inputs for the baseline time block of the survival location-scale model.
///
/// The block is evaluated through three channels (entry, exit, and the time
/// derivative at exit), each with its own design and additive offset.
/// NOTE(review): presumably row `i` of every design/offset is observation
/// `i` — confirm against the builder.
#[derive(Clone)]
pub struct TimeBlockInput {
    /// Time-block design for the entry-time channel.
    pub design_entry: DesignMatrix,
    /// Time-block design for the exit-time channel.
    pub design_exit: DesignMatrix,
    /// Time-block design for the derivative-at-exit channel.
    pub design_derivative_exit: DesignMatrix,
    /// Additive offset on the entry-time channel.
    pub offset_entry: Array1<f64>,
    /// Additive offset on the exit-time channel.
    pub offset_exit: Array1<f64>,
    /// Additive offset on the derivative-at-exit channel.
    pub derivative_offset_exit: Array1<f64>,
    /// How the time block enforces `q'(t) ≥ guard`. The consuming family
    /// dispatches the constraint shape on this and refuses a mismatch
    /// rather than silently producing a degenerate KKT system.
    pub time_monotonicity: TimeBlockMonotonicity,
    /// Penalty matrices on the time-block coefficients.
    /// NOTE(review): presumably one smoothing parameter per entry.
    pub penalties: Vec<Array2<f64>>,
    /// Structural nullspace dimension of each penalty matrix.
    pub nullspace_dims: Vec<usize>,
    /// Optional warm-start log-smoothing parameters for `penalties`.
    pub initial_log_lambdas: Option<Array1<f64>>,
    /// Optional warm-start coefficients for this block.
    pub initial_beta: Option<Array1<f64>>,
}
77
/// A covariate block whose linear predictor depends on the survival time axis
/// via a tensor product: covariate design (n x p_cov) ⊗ B-spline on log(time).
///
/// At row i the linear predictor evaluated at time t is
///
/// ```text
/// eta(t) = [ x_cov(i,:) ⊗ B_time(t) ] @ beta
/// ```
///
/// where B_time(t) is a B-spline basis row evaluated at log(t).
/// The entry and exit tensor designs are
///
/// ```text
/// X_entry[i,:] = x_cov(i,:) ⊗ B_time(t_entry_i)
/// X_exit[i,:]  = x_cov(i,:) ⊗ B_time(t_exit_i)
/// ```
///
/// NOTE(review): this struct stores the covariate design and the time bases
/// separately rather than the tensor designs themselves; confirm where the
/// Kronecker rows are actually formed.
#[derive(Clone)]
pub struct TimeDependentCovariateBlockInput {
    /// Covariate design matrix (n x p_cov), same for all time points.
    pub design_covariates: DesignMatrix,
    /// B-spline time basis at entry times (n x p_time).
    pub time_basis_entry: Array2<f64>,
    /// B-spline time basis at exit times (n x p_time).
    pub time_basis_exit: Array2<f64>,
    /// Derivative of the time basis with respect to clock time at exit.
    pub time_basis_derivative_exit: Array2<f64>,
    /// Combined Kronecker penalties for the tensor product.
    pub penalties: Vec<PenaltyMatrix>,
    /// Optional warm-start log-smoothing parameters for `penalties`.
    pub initial_log_lambdas: Option<Array1<f64>>,
    /// Optional warm-start coefficients for the tensor-product block.
    pub initial_beta: Option<Array1<f64>>,
    /// Fixed offset added to this block's linear predictor.
    pub offset: Array1<f64>,
}
105
/// Whether a covariate block (threshold or log-sigma) is time-invariant or
/// depends on the survival time axis via a tensor product.
#[derive(Clone)]
pub enum CovariateBlockKind {
    /// Time-invariant block described by an ordinary parameter block.
    Static(ParameterBlockInput),
    /// Block whose effect varies with time through a covariate ⊗ time-basis
    /// tensor product.
    TimeVarying(TimeDependentCovariateBlockInput),
}
113
/// Inputs for the optional link-wiggle block.
#[derive(Clone)]
pub struct LinkWiggleBlockInput {
    /// Design matrix of the link-wiggle basis.
    pub design: DesignMatrix,
    /// Knot vector of the link-wiggle basis.
    pub knots: Array1<f64>,
    /// Spline degree paired with `knots`.
    pub degree: usize,
    /// Penalty specifications acting on the link-wiggle coefficients.
    pub penalties: Vec<gam_terms::penalty_spec::PenaltySpec>,
    /// Structural nullspace dimension of each penalty matrix.
    pub nullspace_dims: Vec<usize>,
    /// Optional warm-start log-smoothing parameters for `penalties`.
    pub initial_log_lambdas: Option<Array1<f64>>,
    /// Optional warm-start coefficients for this block.
    pub initial_beta: Option<Array1<f64>>,
}
125
/// Basis description for the optional time-wiggle block.
#[derive(Clone)]
pub struct TimeWiggleBlockInput {
    /// Knot vector of the time-wiggle basis.
    pub knots: Array1<f64>,
    /// Spline degree paired with `knots`.
    pub degree: usize,
    /// Column count of the time-wiggle basis.
    /// NOTE(review): presumably the coefficient count — confirm.
    pub ncols: usize,
}
132
/// Fully assembled (design-level) specification of a survival
/// location-scale fit.
#[derive(Clone)]
pub(crate) struct SurvivalLocationScaleSpec {
    /// Entry (left-truncation) age per row.
    pub age_entry: Array1<f64>,
    /// Exit age per row.
    pub age_exit: Array1<f64>,
    /// Event indicator / target per row.
    pub event_target: Array1<f64>,
    /// Prior weight per row.
    pub weights: Array1<f64>,
    /// Inverse link mapping the standardized predictor to probability.
    pub inverse_link: InverseLink,
    /// Strict lower bound on the time-derivative of the predictor.
    pub derivative_guard: f64,
    /// Iteration cap for the fit.
    pub max_iter: usize,
    /// Convergence tolerance for the fit.
    pub tol: f64,
    /// Baseline time block.
    pub time_block: TimeBlockInput,
    /// Threshold (location) covariate block.
    pub threshold_block: CovariateBlockKind,
    /// Log-sigma (scale) covariate block.
    pub log_sigma_block: CovariateBlockKind,
    /// Optional time-wiggle block.
    pub timewiggle_block: Option<TimeWiggleBlockInput>,
    /// Optional link-wiggle block.
    pub linkwiggle_block: Option<LinkWiggleBlockInput>,
    /// Explicit persistent warm-start cache session. See
    /// [`crate::custom_family::BlockwiseFitOptions::cache_session`].
    pub cache_session: Option<std::sync::Arc<gam_runtime::warm_start::Session>>,
    /// Persistent warm-start mirror sessions; see
    /// [`crate::custom_family::BlockwiseFitOptions::cache_mirror_sessions`].
    pub cache_mirror_sessions: Vec<std::sync::Arc<gam_runtime::warm_start::Session>>,
}
155
/// Template describing how a term-collection covariate block (threshold or
/// log-sigma) is turned into a block input: either as-is or tensored with a
/// time basis.
#[derive(Clone)]
pub enum SurvivalCovariateTermBlockTemplate {
    /// Use the covariate design directly (time-invariant block).
    Static,
    /// Tensor the covariate design with the supplied time bases.
    TimeVarying {
        /// Time basis at entry times.
        time_basis_entry: Array2<f64>,
        /// Time basis at exit times.
        time_basis_exit: Array2<f64>,
        /// Time derivative of the basis at exit times.
        time_basis_derivative_exit: Array2<f64>,
        /// Penalties on the time marginal of the tensor product.
        time_penalties: Vec<Array2<f64>>,
    },
}
166
/// Term-level specification of a survival location-scale fit.
///
/// Threshold and log-sigma blocks are given as `TermCollectionSpec`s (plus
/// templates) rather than as prebuilt designs.
#[derive(Clone)]
pub struct SurvivalLocationScaleTermSpec {
    /// Entry (left-truncation) age per row.
    pub age_entry: Array1<f64>,
    /// Exit age per row.
    pub age_exit: Array1<f64>,
    /// Event indicator / target per row.
    pub event_target: Array1<f64>,
    /// Prior weight per row.
    pub weights: Array1<f64>,
    /// Inverse link mapping the standardized predictor to probability.
    pub inverse_link: InverseLink,
    /// Strict lower bound on d_eta/dt used by both the event Jacobian term
    /// and the time monotonicity constraints.
    pub derivative_guard: f64,
    /// Iteration cap for the fit.
    pub max_iter: usize,
    /// Convergence tolerance for the fit.
    pub tol: f64,
    /// Baseline time block.
    pub time_block: TimeBlockInput,
    /// Term specification of the threshold (location) block.
    pub thresholdspec: TermCollectionSpec,
    /// Term specification of the log-sigma (scale) block.
    pub log_sigmaspec: TermCollectionSpec,
    /// Fixed offset on the threshold predictor.
    pub threshold_offset: Array1<f64>,
    /// Fixed offset on the log-sigma predictor.
    pub log_sigma_offset: Array1<f64>,
    /// Whether the threshold block is static or time-varying.
    pub threshold_template: SurvivalCovariateTermBlockTemplate,
    /// Whether the log-sigma block is static or time-varying.
    pub log_sigma_template: SurvivalCovariateTermBlockTemplate,
    /// Optional time-wiggle block.
    pub timewiggle_block: Option<TimeWiggleBlockInput>,
    /// Optional link-wiggle block.
    pub linkwiggle_block: Option<LinkWiggleBlockInput>,
    /// Optional warm-start seed for the threshold-block log-smoothing parameters (ρ).
    /// When `Some`, its length must equal the number of threshold penalties; values are
    /// clamped to the outer-loop ρ bounds before being injected into `rho0`.
    /// Used by the outer baseline-config optimizer to thread converged smoothing
    /// from one probe into the next.
    pub initial_threshold_log_lambdas: Option<Array1<f64>>,
    /// Optional warm-start seed for the log-sigma-block log-smoothing parameters (ρ).
    /// Same semantics as `initial_threshold_log_lambdas`.
    pub initial_log_sigma_log_lambdas: Option<Array1<f64>>,
    /// Explicit persistent warm-start cache session. See
    /// [`crate::custom_family::BlockwiseFitOptions::cache_session`].
    pub cache_session: Option<std::sync::Arc<gam_runtime::warm_start::Session>>,
    /// Explicit persistent warm-start mirror sessions. See
    /// [`crate::custom_family::BlockwiseFitOptions::cache_mirror_sessions`].
    pub cache_mirror_sessions: Vec<std::sync::Arc<gam_runtime::warm_start::Session>>,
}
204
/// Default strict lower bound on `d_eta/dt` (the `derivative_guard` of a
/// survival location-scale spec), i.e. `10⁻⁶`.
pub const DEFAULT_SURVIVAL_LOCATION_SCALE_DERIVATIVE_GUARD: f64 = 1.0e-6;
206
/// Result of a term-level survival location-scale fit.
///
/// Holds the unified fit together with the resolved term specs and designs,
/// plus derivative information used by outer baseline / link optimizers.
pub struct SurvivalLocationScaleTermFitResult {
    /// Unified fit result.
    pub fit: UnifiedFitResult,
    /// Threshold term spec after resolution.
    pub resolved_thresholdspec: TermCollectionSpec,
    /// Log-sigma term spec after resolution.
    pub resolved_log_sigmaspec: TermCollectionSpec,
    /// Threshold design built from the resolved spec.
    pub threshold_design: TermCollectionDesign,
    /// Log-sigma design built from the resolved spec.
    pub log_sigma_design: TermCollectionDesign,
    /// Per-row gradient of unpenalized NLL w.r.t. the three additive time-block
    /// offset channels (entry / exit / derivative-at-exit) at the converged β.
    /// Contracted with `∂o/∂θ_baseline` this yields the analytic θ-gradient
    /// used by the with-gradient baseline optimizer.
    pub baseline_offset_residuals: OffsetChannelResiduals,
    /// 3×3 NLL Hessian per row on the offset channels, in
    /// `(entry, exit, derivative)` order. Diagonal under location-scale —
    /// the row likelihood is separable in `(u0, u1, g)`. Used by the analytic
    /// θ-Hessian builder (chain rule second derivative).
    pub baseline_offset_curvatures: OffsetChannelCurvatures,
    /// Exact data-fit gradient `∂(−ℓ)/∂θ_link` of the unpenalized
    /// log-likelihood w.r.t. the inverse-link parameters at the converged β̂
    /// (`None` when the inverse link carries no free parameters). Equals the
    /// envelope-theorem θ_link-gradient of the profile penalized NLL, consumed
    /// by the inverse-link BFGS optimizer.
    pub link_param_data_fit_gradient: Option<Array1<f64>>,
}
230
/// Helper struct so callers can build a `UnifiedFitResult` from
/// survival-specific fields without knowing about the unified layout.
pub struct SurvivalLocationScaleFitResultParts {
    /// Coefficients of the time block.
    pub beta_time: Array1<f64>,
    /// Coefficients of the threshold block.
    pub beta_threshold: Array1<f64>,
    /// Coefficients of the log-sigma block.
    pub beta_log_sigma: Array1<f64>,
    /// Link-wiggle coefficients; when `Some`, knots and degree must also be set.
    pub beta_link_wiggle: Option<Array1<f64>>,
    /// Link-wiggle knots; only valid together with `beta_link_wiggle`.
    pub link_wiggle_knots: Option<Array1<f64>>,
    /// Link-wiggle degree; only valid together with `beta_link_wiggle`.
    pub link_wiggle_degree: Option<usize>,
    /// Smoothing parameters (λ, not log λ) of the time block.
    pub lambdas_time: Array1<f64>,
    /// Smoothing parameters of the threshold block.
    pub lambdas_threshold: Array1<f64>,
    /// Smoothing parameters of the log-sigma block.
    pub lambdas_log_sigma: Array1<f64>,
    /// Smoothing parameters of the link-wiggle block, if any.
    pub lambdas_linkwiggle: Option<Array1<f64>>,
    /// Log-likelihood at the fitted coefficients.
    pub log_likelihood: f64,
    /// Outer-criterion (REML) score.
    pub reml_score: f64,
    /// Penalty contribution reported alongside the objective.
    pub stable_penalty_term: f64,
    /// Penalized objective value.
    pub penalized_objective: f64,
    /// Whether any GPU device executed part of this fit (GPU-flag propagation).
    /// Survival location-scale runs on the CPU path, so this is `false`; it is
    /// carried so the assembled `UnifiedFitResultParts` reports a real value.
    pub used_device: bool,
    /// Number of outer iterations performed.
    pub outer_iterations: usize,
    /// `None` = no gradient measured at termination; `Some(g)` = measured.
    /// `outer_converged` is the authoritative convergence signal.
    pub outer_gradient_norm: Option<f64>,
    /// Whether the outer optimizer reported convergence.
    pub outer_converged: bool,
    /// Optional conditional covariance; must be `p × p` over all blocks.
    pub covariance_conditional: Option<Array2<f64>>,
    /// Optional fit geometry; its penalized Hessian must be `p × p`.
    pub geometry: Option<FitGeometry>,
}
260
261#[derive(Clone, Copy)]
262pub(crate) struct SurvivalLambdaLayout {
263    pub(crate) k_time: usize,
264    pub(crate) k_threshold: usize,
265    pub(crate) k_log_sigma: usize,
266    pub(crate) k_wiggle: usize,
267}
268
269impl SurvivalLambdaLayout {
270    pub(crate) fn new(
271        k_time: usize,
272        k_threshold: usize,
273        k_log_sigma: usize,
274        k_wiggle: usize,
275    ) -> Self {
276        Self {
277            k_time,
278            k_threshold,
279            k_log_sigma,
280            k_wiggle,
281        }
282    }
283
284    pub(crate) fn total(&self) -> usize {
285        self.k_time + self.k_threshold + self.k_log_sigma + self.k_wiggle
286    }
287
288    pub(crate) fn time_range(&self) -> std::ops::Range<usize> {
289        0..self.k_time
290    }
291
292    pub(crate) fn threshold_range(&self) -> std::ops::Range<usize> {
293        self.k_time..self.k_time + self.k_threshold
294    }
295
296    pub(crate) fn log_sigma_range(&self) -> std::ops::Range<usize> {
297        self.k_time + self.k_threshold..self.k_time + self.k_threshold + self.k_log_sigma
298    }
299
300    pub(crate) fn wiggle_range(&self) -> std::ops::Range<usize> {
301        self.k_time + self.k_threshold + self.k_log_sigma..self.total()
302    }
303
304    pub(crate) fn validate_rho(&self, rho: &Array1<f64>, label: &str) -> Result<(), String> {
305        if rho.len() != self.total() {
306            return Err(SurvivalLocationScaleError::DimensionMismatch {
307                reason: format!(
308                    "{label} rho length mismatch: got {}, expected {}",
309                    rho.len(),
310                    self.total()
311                ),
312            }
313            .into());
314        }
315        Ok::<(), _>(())
316    }
317
318    pub(crate) fn time_from(&self, rho: &Array1<f64>) -> Array1<f64> {
319        let range = self.time_range();
320        rho.slice(s![range.start..range.end]).to_owned()
321    }
322
323    pub(crate) fn threshold_from(&self, rho: &Array1<f64>) -> Array1<f64> {
324        let range = self.threshold_range();
325        rho.slice(s![range.start..range.end]).to_owned()
326    }
327
328    pub(crate) fn log_sigma_from(&self, rho: &Array1<f64>) -> Array1<f64> {
329        let range = self.log_sigma_range();
330        rho.slice(s![range.start..range.end]).to_owned()
331    }
332
333    pub(crate) fn wiggle_from(&self, rho: &Array1<f64>) -> Option<Array1<f64>> {
334        if self.k_wiggle == 0 {
335            None
336        } else {
337            let range = self.wiggle_range();
338            Some(rho.slice(s![range.start..range.end]).to_owned())
339        }
340    }
341}
342
343/// Build a `UnifiedFitResult` from survival-specific fields.
344pub fn survival_fit_from_parts(
345    parts: SurvivalLocationScaleFitResultParts,
346) -> Result<UnifiedFitResult, String> {
347    let SurvivalLocationScaleFitResultParts {
348        beta_time,
349        beta_threshold,
350        beta_log_sigma,
351        beta_link_wiggle,
352        link_wiggle_knots,
353        link_wiggle_degree,
354        lambdas_time,
355        lambdas_threshold,
356        lambdas_log_sigma,
357        lambdas_linkwiggle,
358        log_likelihood,
359        reml_score,
360        stable_penalty_term,
361        penalized_objective,
362        used_device,
363        outer_iterations,
364        outer_gradient_norm,
365        outer_converged,
366        covariance_conditional,
367        geometry,
368    } = parts;
369
370    // Validation (preserved from the old impl).
371    validate_all_finite_estimation("survival_fit.beta_time", beta_time.iter().copied())
372        .map_err(|e| e.to_string())?;
373    validate_all_finite_estimation(
374        "survival_fit.beta_threshold",
375        beta_threshold.iter().copied(),
376    )
377    .map_err(|e| e.to_string())?;
378    validate_all_finite_estimation(
379        "survival_fit.beta_log_sigma",
380        beta_log_sigma.iter().copied(),
381    )
382    .map_err(|e| e.to_string())?;
383    if let Some(beta_wiggle) = beta_link_wiggle.as_ref() {
384        validate_all_finite_estimation(
385            "survival_fit.beta_link_wiggle",
386            beta_wiggle.iter().copied(),
387        )
388        .map_err(|e| e.to_string())?;
389        let knots = link_wiggle_knots.as_ref().ok_or_else(|| {
390            "survival_fit.beta_link_wiggle requires link_wiggle_knots".to_string()
391        })?;
392        validate_all_finite_estimation("survival_fit.link_wiggle_knots", knots.iter().copied())
393            .map_err(|e| e.to_string())?;
394        if link_wiggle_degree.is_none() {
395            return Err(SurvivalLocationScaleError::InvalidConfiguration {
396                reason: "survival_fit.beta_link_wiggle requires link_wiggle_degree".to_string(),
397            }
398            .into());
399        }
400    } else if link_wiggle_knots.is_some() || link_wiggle_degree.is_some() {
401        return Err(SurvivalLocationScaleError::InvalidConfiguration {
402            reason: "survival_fit link-wiggle metadata requires beta_link_wiggle coefficients"
403                .to_string(),
404        }
405        .into());
406    }
407    validate_all_finite_estimation("survival_fit.lambdas_time", lambdas_time.iter().copied())
408        .map_err(|e| e.to_string())?;
409    validate_all_finite_estimation(
410        "survival_fit.lambdas_threshold",
411        lambdas_threshold.iter().copied(),
412    )
413    .map_err(|e| e.to_string())?;
414    validate_all_finite_estimation(
415        "survival_fit.lambdas_log_sigma",
416        lambdas_log_sigma.iter().copied(),
417    )
418    .map_err(|e| e.to_string())?;
419    // Each block's smoothing-parameter count counts the number of distinct
420    // penalty terms acting on that block's coefficients. A penalty term cannot
421    // outnumber the coefficients it penalizes, so reject `lambdas_<block>`
422    // vectors longer than the corresponding `beta_<block>`. This catches stale
423    // / misaligned lambda slices that would otherwise propagate silently into
424    // downstream inference where the per-block penalty bookkeeping is
425    // unrecoverable.
426    if lambdas_time.len() > beta_time.len() {
427        return Err(SurvivalLocationScaleError::DimensionMismatch {
428            reason: format!(
429                "survival_fit.lambdas_time has {} entries but beta_time has only {} \
430                 coefficients; each lambda corresponds to a penalty term on this block",
431                lambdas_time.len(),
432                beta_time.len()
433            ),
434        }
435        .into());
436    }
437    if lambdas_threshold.len() > beta_threshold.len() {
438        return Err(SurvivalLocationScaleError::DimensionMismatch {
439            reason: format!(
440                "survival_fit.lambdas_threshold has {} entries but beta_threshold has only {} \
441                 coefficients; each lambda corresponds to a penalty term on this block",
442                lambdas_threshold.len(),
443                beta_threshold.len()
444            ),
445        }
446        .into());
447    }
448    if lambdas_log_sigma.len() > beta_log_sigma.len() {
449        return Err(SurvivalLocationScaleError::DimensionMismatch {
450            reason: format!(
451                "survival_fit.lambdas_log_sigma has {} entries but beta_log_sigma has only {} \
452                 coefficients; each lambda corresponds to a penalty term on this block",
453                lambdas_log_sigma.len(),
454                beta_log_sigma.len()
455            ),
456        }
457        .into());
458    }
459    if let Some(lambdas_wiggle) = lambdas_linkwiggle.as_ref() {
460        if beta_link_wiggle.is_none() {
461            return Err(SurvivalLocationScaleError::InvalidConfiguration {
462                reason: "survival_fit.lambdas_linkwiggle requires beta_link_wiggle".to_string(),
463            }
464            .into());
465        }
466        validate_all_finite_estimation(
467            "survival_fit.lambdas_linkwiggle",
468            lambdas_wiggle.iter().copied(),
469        )
470        .map_err(|e| e.to_string())?;
471        let wiggle_len = beta_link_wiggle.as_ref().map_or(0, |beta| beta.len());
472        if lambdas_wiggle.len() > wiggle_len {
473            return Err(SurvivalLocationScaleError::DimensionMismatch {
474                reason: format!(
475                    "survival_fit.lambdas_linkwiggle has {} entries but beta_link_wiggle has \
476                     only {} coefficients; each lambda corresponds to a penalty term on this block",
477                    lambdas_wiggle.len(),
478                    wiggle_len
479                ),
480            }
481            .into());
482        }
483    }
484    ensure_finite_scalar_estimation("survival_fit.log_likelihood", log_likelihood)
485        .map_err(|e| e.to_string())?;
486    ensure_finite_scalar_estimation("survival_fit.reml_score", reml_score)
487        .map_err(|e| e.to_string())?;
488    ensure_finite_scalar_estimation("survival_fit.stable_penalty_term", stable_penalty_term)
489        .map_err(|e| e.to_string())?;
490    ensure_finite_scalar_estimation("survival_fit.penalized_objective", penalized_objective)
491        .map_err(|e| e.to_string())?;
492    if let Some(g) = outer_gradient_norm {
493        ensure_finite_scalar_estimation("survival_fit.outer_gradient_norm", g)
494            .map_err(|e| e.to_string())?;
495    }
496
497    let total_p = beta_time.len()
498        + beta_threshold.len()
499        + beta_log_sigma.len()
500        + beta_link_wiggle.as_ref().map_or(0, |beta| beta.len());
501    if let Some(cov) = covariance_conditional.as_ref() {
502        validate_all_finite_estimation("survival_fit.covariance_conditional", cov.iter().copied())
503            .map_err(|e| e.to_string())?;
504        let (rows, cols) = cov.dim();
505        if rows != total_p || cols != total_p {
506            return Err(SurvivalLocationScaleError::InvalidConfiguration {
507                reason: format!(
508                    "survival_fit.covariance_conditional must be {}x{}, got {}x{}",
509                    total_p, total_p, rows, cols
510                ),
511            }
512            .into());
513        }
514    }
515    if let Some(geom) = geometry.as_ref() {
516        geom.validate_numeric_finiteness()
517            .map_err(|e| e.to_string())?;
518        let (rows, cols) = geom.penalized_hessian.dim();
519        if rows != total_p || cols != total_p {
520            return Err(SurvivalLocationScaleError::InvalidConfiguration {
521                reason: format!(
522                    "survival_fit.geometry.penalized_hessian must be {}x{}, got {}x{}",
523                    total_p, total_p, rows, cols
524                ),
525            }
526            .into());
527        }
528        if geom.working_weights.len() != geom.working_response.len() {
529            return Err(SurvivalLocationScaleError::DimensionMismatch {
530                reason: format!(
531                    "survival_fit.geometry working length mismatch: weights={}, response={}",
532                    geom.working_weights.len(),
533                    geom.working_response.len()
534                ),
535            }
536            .into());
537        }
538    }
539
540    // Build blocks for the unified representation.
541    use crate::model_types::{BlockRole, FittedBlock, FittedLinkState, UnifiedFitResultParts};
542    let mut blocks = vec![
543        FittedBlock {
544            beta: beta_time.clone(),
545            role: BlockRole::Time,
546            edf: 0.0,
547            lambdas: lambdas_time.clone(),
548        },
549        FittedBlock {
550            beta: beta_threshold.clone(),
551            role: BlockRole::Threshold,
552            edf: 0.0,
553            lambdas: lambdas_threshold.clone(),
554        },
555        FittedBlock {
556            beta: beta_log_sigma.clone(),
557            role: BlockRole::Scale,
558            edf: 0.0,
559            lambdas: lambdas_log_sigma.clone(),
560        },
561    ];
562    if let Some(ref bw) = beta_link_wiggle {
563        blocks.push(FittedBlock {
564            beta: bw.clone(),
565            role: BlockRole::LinkWiggle,
566            edf: 0.0,
567            lambdas: lambdas_linkwiggle
568                .clone()
569                .unwrap_or_else(|| Array1::zeros(0)),
570        });
571    }
572    let all_lambdas: Vec<f64> = blocks
573        .iter()
574        .flat_map(|b| b.lambdas.iter().copied())
575        .collect();
576    let log_lambdas = Array1::from_vec(
577        all_lambdas
578            .iter()
579            .map(|&v| if v > 0.0 { v.ln() } else { f64::NEG_INFINITY })
580            .collect(),
581    );
582    let deviance = -2.0 * log_likelihood;
583    crate::model_types::UnifiedFitResult::try_from_parts(UnifiedFitResultParts {
584        blocks,
585        log_lambdas,
586        lambdas: Array1::from_vec(all_lambdas),
587        likelihood_family: None,
588        likelihood_scale: gam_problem::LikelihoodScaleMetadata::Unspecified,
589        log_likelihood_normalization: gam_problem::LogLikelihoodNormalization::UserProvided,
590        log_likelihood,
591        deviance,
592        reml_score,
593        stable_penalty_term,
594        penalized_objective,
595        used_device,
596        outer_iterations,
597        outer_converged,
598        outer_gradient_norm,
599        standard_deviation: 1.0,
600        covariance_conditional,
601        covariance_corrected: None,
602        inference: None,
603        fitted_link: FittedLinkState::Standard(None),
604        geometry,
605        block_states: Vec::new(),
606        pirls_status: gam_solve::pirls::PirlsStatus::Converged,
607        max_abs_eta: 0.0,
608        constraint_kkt: None,
609        artifacts: crate::model_types::FitArtifacts {
610            pirls: None,
611            null_space_logdet: None,
612            null_space_dim: None,
613            survival_link_wiggle_knots: link_wiggle_knots,
614            survival_link_wiggle_degree: link_wiggle_degree,
615            criterion_certificate: None,
616            rho_posterior_certificate: None,
617            rho_posterior_escalation: None,
618            rho_covariance: None,
619        },
620        inner_cycles: 0,
621    })
622    .map_err(|e| e.to_string())
623}
624
/// Inputs for predicting from a fitted survival location-scale model.
/// NOTE(review): presumably every design/offset is evaluated at the
/// prediction (exit) time of each row — confirm against the predictor.
#[derive(Clone)]
pub struct SurvivalLocationScalePredictInput {
    /// Time-block design at the exit time.
    pub x_time_exit: Array2<f64>,
    /// Additive offset on the time-block predictor at exit.
    pub eta_time_offset_exit: Array1<f64>,
    /// Time-wiggle knots, when a time wiggle was fitted.
    pub time_wiggle_knots: Option<Array1<f64>>,
    /// Time-wiggle degree, when a time wiggle was fitted.
    pub time_wiggle_degree: Option<usize>,
    /// Column count of the time-wiggle basis.
    pub time_wiggle_ncols: usize,
    /// Threshold-block design.
    pub x_threshold: DesignMatrix,
    /// Additive offset on the threshold predictor.
    pub eta_threshold_offset: Array1<f64>,
    /// Log-sigma-block design.
    pub x_log_sigma: DesignMatrix,
    /// Additive offset on the log-sigma predictor.
    pub eta_log_sigma_offset: Array1<f64>,
    /// Link-wiggle design, when a link wiggle was fitted.
    pub x_link_wiggle: Option<DesignMatrix>,
    /// Link-wiggle knots, when a link wiggle was fitted.
    pub link_wiggle_knots: Option<Array1<f64>>,
    /// Link-wiggle degree, when a link wiggle was fitted.
    pub link_wiggle_degree: Option<usize>,
    /// Inverse link used to map the predictor to a probability.
    pub inverse_link: InverseLink,
}
641
/// Point predictions from a survival location-scale model.
#[derive(Clone, Debug)]
pub struct SurvivalLocationScalePredictResult {
    /// Linear predictor per row.
    pub eta: Array1<f64>,
    /// Survival probability per row.
    pub survival_prob: Array1<f64>,
}
647
/// Predictions from a survival location-scale model with uncertainty.
#[derive(Clone)]
pub struct SurvivalLocationScalePredictUncertaintyResult {
    /// Linear predictor per row.
    pub eta: Array1<f64>,
    /// Survival probability per row.
    pub survival_prob: Array1<f64>,
    /// Standard error of `eta` per row.
    pub eta_standard_error: Array1<f64>,
    /// Standard error on the response scale, when computed.
    pub response_standard_error: Option<Array1<f64>>,
}
655
656pub(crate) fn initial_log_lambdas<T>(
657    penalties: &[T],
658    rho0: Option<Array1<f64>>,
659) -> Result<Array1<f64>, String> {
660    let k = penalties.len();
661    let rho = rho0.unwrap_or_else(|| Array1::zeros(k));
662    if rho.len() != k {
663        return Err(SurvivalLocationScaleError::DimensionMismatch {
664            reason: format!(
665                "initial_log_lambdas mismatch: got {}, expected {k}",
666                rho.len()
667            ),
668        }
669        .into());
670    }
671    Ok(rho)
672}