// gam/solver/workflow.rs

1use crate::custom_family::BlockwiseFitOptions;
2use crate::estimate::{EstimationError, FitOptions, FittedLinkState, UnifiedFitResult};
3use crate::families::bernoulli_marginal_slope::{
4    BernoulliMarginalSlopeFitResult, BernoulliMarginalSlopeTermSpec, DeviationBlockConfig,
5    fit_bernoulli_marginal_slope_terms,
6};
7use crate::families::gamlss::{
8    BinomialLocationScaleFitResult, BinomialLocationScaleTermSpec, BlockwiseTermFitResult,
9    BlockwiseTermFitResultParts, GaussianLocationScaleFitResult, GaussianLocationScaleTermSpec,
10    WiggleBlockConfig, fit_binomial_location_scale_terms,
11    fit_binomial_location_scale_terms_with_selected_wiggle,
12    fit_binomial_mean_wiggle_terms_with_selected_basis, fit_gaussian_location_scale_terms,
13    fit_gaussian_location_scale_terms_with_selected_wiggle,
14    select_binomial_location_scale_link_wiggle_basis_from_pilot,
15    select_binomial_mean_link_wiggle_basis_from_pilot,
16    select_gaussian_location_scale_link_wiggle_basis_from_pilot,
17};
18use crate::families::latent_survival::{
19    LatentBinaryTermFitResult, LatentBinaryTermSpec, LatentSurvivalTermFitResult,
20    LatentSurvivalTermSpec, fit_latent_binary_terms, fit_latent_survival_terms,
21    latent_hazard_loading,
22};
23use crate::families::lognormal_kernel::FrailtySpec;
24use crate::families::survival_location_scale::{
25    SurvivalLocationScaleTermFitResult, SurvivalLocationScaleTermSpec,
26    fit_survival_location_scale_terms, fit_survival_location_scale_terms_with_selected_wiggle,
27    select_survival_link_wiggle_basis_from_pilot,
28};
29use crate::families::survival_marginal_slope::{
30    SurvivalMarginalSlopeFitResult, SurvivalMarginalSlopeTermSpec,
31    fit_survival_marginal_slope_terms,
32};
33use crate::families::transformation_normal::{
34    TransformationNormalConfig, TransformationNormalFitResult, TransformationWarmStart,
35    fit_transformation_normal,
36};
37use crate::mixture_link::{state_from_beta_logisticspec, state_from_sasspec, state_fromspec};
38use crate::smooth::{
39    AdaptiveRegularizationDiagnostics, SpatialLengthScaleOptimizationOptions, TermCollectionDesign,
40    TermCollectionSpec, build_term_collection_design,
41    fit_term_collectionwith_spatial_length_scale_optimization,
42};
43use crate::types::{
44    InverseLink, LatentCLogLogState, LikelihoodFamily, LinkFunction, MixtureLinkSpec, SasLinkSpec,
45    WigglePenaltyConfig,
46};
47use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
48use std::collections::HashMap;
49
/// Settings for an optional link-wiggle block.
///
/// In the workflows in this file, `degree`, `num_internal_knots` and
/// `double_penalty` are copied into a `WiggleBlockConfig`, while
/// `penalty_orders` is handed to the pilot-based basis selectors as a
/// separate argument.
#[derive(Clone, Debug)]
pub struct LinkWiggleConfig {
    /// Copied into `WiggleBlockConfig::degree`.
    pub degree: usize,
    /// Copied into `WiggleBlockConfig::num_internal_knots`.
    pub num_internal_knots: usize,
    /// Passed on its own to the wiggle-basis selection helpers.
    pub penalty_orders: Vec<usize>,
    /// Copied into `WiggleBlockConfig::double_penalty`.
    pub double_penalty: bool,
}
57
/// Wiggle settings carried by `StandardFitRequest::wiggle`.
#[derive(Clone, Debug)]
pub struct StandardBinomialWiggleConfig {
    /// Fallback inverse link: `resolved_wiggle_inverse_link` clones it when the
    /// first-stage fit reports `FittedLinkState::Standard(None)`.
    pub link_kind: InverseLink,
    /// Basis and penalty settings for the wiggle block.
    pub wiggle: LinkWiggleConfig,
}
63
/// Inputs consumed by `fit_standard_model`.
pub struct StandardFitRequest<'a> {
    /// Borrowed data matrix, forwarded to the first-stage fit and to the wiggle refit.
    pub data: ArrayView2<'a, f64>,
    /// Response vector, passed to the first-stage fit and borrowed by the wiggle refit.
    pub y: Array1<f64>,
    /// Observation weights, passed to the first-stage fit and borrowed by the wiggle refit.
    pub weights: Array1<f64>,
    /// Offset vector; only the first-stage fit receives it directly.
    pub offset: Array1<f64>,
    /// Term specification for the first-stage fit.
    pub spec: TermCollectionSpec,
    /// Likelihood family for the first-stage fit; also used to resolve the wiggle link.
    pub family: LikelihoodFamily,
    /// Options for the first-stage fit.
    pub options: FitOptions,
    /// Spatial length-scale optimization options, shared by both stages.
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
    /// When `Some`, a wiggle refit runs after the first-stage fit.
    pub wiggle: Option<StandardBinomialWiggleConfig>,
    /// Blockwise options for the wiggle refit; must be `Some` whenever `wiggle` is `Some`.
    pub wiggle_options: Option<BlockwiseFitOptions>,
}
76
/// Inputs consumed by `fit_gaussian_location_scale_model`.
pub struct GaussianLocationScaleFitRequest<'a> {
    /// Borrowed data matrix, forwarded to every fit call.
    pub data: ArrayView2<'a, f64>,
    /// Term specification; its fields are cloned for the pilot fit when a wiggle is requested.
    pub spec: GaussianLocationScaleTermSpec,
    /// When `Some`, a pilot fit selects a wiggle basis and the model is refit with it.
    pub wiggle: Option<LinkWiggleConfig>,
    /// Blockwise fit options, shared by the pilot fit and the final fit.
    pub options: BlockwiseFitOptions,
    /// Spatial length-scale optimization options, shared by the pilot fit and the final fit.
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
}
84
/// Inputs consumed by `fit_binomial_location_scale_model`.
pub struct BinomialLocationScaleFitRequest<'a> {
    /// Borrowed data matrix, forwarded to every fit call.
    pub data: ArrayView2<'a, f64>,
    /// Term specification; when a wiggle is requested its `link_kind` must pass
    /// `require_inverse_link_supports_joint_wiggle`, and its fields are cloned
    /// for the pilot fit.
    pub spec: BinomialLocationScaleTermSpec,
    /// When `Some`, a pilot fit selects a wiggle basis and the model is refit with it.
    pub wiggle: Option<LinkWiggleConfig>,
    /// Blockwise fit options, shared by the pilot fit and the final fit.
    pub options: BlockwiseFitOptions,
    /// Spatial length-scale optimization options, shared by the pilot fit and the final fit.
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
}
92
/// Request payload carried by `FitRequest::SurvivalLocationScale`.
///
/// NOTE(review): the function that consumes this request is not part of the
/// section reviewed here, so the field notes are limited to what the types show.
pub struct SurvivalLocationScaleFitRequest<'a> {
    /// Borrowed data matrix.
    pub data: ArrayView2<'a, f64>,
    /// Survival location-scale term specification.
    pub spec: SurvivalLocationScaleTermSpec,
    /// Optional link-wiggle settings.
    pub wiggle: Option<LinkWiggleConfig>,
    /// Spatial length-scale optimization options.
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
    /// NOTE(review): presumably enables optimization of the inverse link's free
    /// parameters (cf. `survival_inverse_link_has_free_parameters`) — confirm
    /// against the consumer.
    pub optimize_inverse_link: bool,
}
100
/// Inputs consumed by `fit_survival_transformation_model`.
pub struct SurvivalTransformationFitRequest<'a> {
    /// Borrowed data matrix used to build the covariate design.
    pub data: ArrayView2<'a, f64>,
    /// Times, events, weights, baseline and penalty settings for the fit.
    pub spec: SurvivalTransformationTermSpec,
}
105
/// Model inputs read by `fit_survival_transformation_model`.
#[derive(Clone)]
pub struct SurvivalTransformationTermSpec {
    /// Entry ages; passed to the time-stack preparation and to the working model.
    pub age_entry: Array1<f64>,
    /// Exit ages; passed to the time-stack preparation and to the working model.
    pub age_exit: Array1<f64>,
    /// Target-event indicator per observation; its length also sizes the
    /// all-zero competing-event vector built by the fit.
    pub event_target: Array1<u8>,
    /// Observation weights handed to the working model.
    pub weights: Array1<f64>,
    /// Term specification from which the covariate design is built.
    pub covariate_spec: TermCollectionSpec,
    /// Offset added to both the entry and the exit linear-predictor offsets.
    pub covariate_offset: Array1<f64>,
    /// Starting baseline configuration; it is optimized before the final fit
    /// unless its target is `SurvivalBaselineTarget::Linear`.
    pub baseline_cfg: crate::families::survival_construction::SurvivalBaselineConfig,
    /// Likelihood mode; the Weibull mode skips structural monotonicity and
    /// seeds the first two time coefficients instead.
    pub likelihood_mode: crate::families::survival_construction::SurvivalLikelihoodMode,
    /// Time anchor forwarded to the time-stack preparation.
    pub time_anchor: f64,
    /// Time-basis build output; its `smooth_lambda` (default `1e-2` when absent)
    /// is used as the lambda of every time penalty block.
    pub time_build: crate::families::survival_construction::SurvivalTimeBuildOutput,
    /// Optional time-wiggle specification forwarded to the time-stack preparation.
    pub timewiggle: Option<LinkWiggleFormulaSpec>,
    /// `(scale, shape)` seed; required when the mode is Weibull and no time wiggle is set.
    pub weibull_seed: Option<(f64, f64)>,
    /// When positive, an identity ridge penalty block with this lambda is added.
    pub ridge_lambda: f64,
}
122
123pub(crate) fn survival_inverse_link_has_free_parameters(link: &InverseLink) -> bool {
124    match link {
125        InverseLink::Sas(_) | InverseLink::BetaLogistic(_) => true,
126        InverseLink::Mixture(state) => !state.rho.is_empty(),
127        InverseLink::LatentCLogLog(_) | InverseLink::Standard(_) => false,
128    }
129}
130
131fn recover_converged_survival_inverse_link<R>(
132    result: crate::solver::outer_strategy::OuterResult,
133    context: &str,
134    recover: R,
135) -> Result<InverseLink, String>
136where
137    R: FnOnce(&Array1<f64>) -> Option<InverseLink>,
138{
139    if !result.converged {
140        return Err(format!(
141            "{context} did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={:.3e})",
142            result.iterations, result.final_value, result.final_grad_norm
143        ));
144    }
145    recover(&result.rho).ok_or_else(|| {
146        format!(
147            "{context} produced an invalid inverse-link state at rho={:?}",
148            result.rho.to_vec()
149        )
150    })
151}
152
/// Request payload carried by `FitRequest::BernoulliMarginalSlope`.
///
/// NOTE(review): the consumer of this request is not part of the section
/// reviewed here; field notes are limited to what the types show.
pub struct BernoulliMarginalSlopeFitRequest<'a> {
    /// Borrowed data matrix.
    pub data: ArrayView2<'a, f64>,
    /// Bernoulli marginal-slope term specification.
    pub spec: BernoulliMarginalSlopeTermSpec,
    /// Blockwise fit options.
    pub options: BlockwiseFitOptions,
    /// Spatial length-scale optimization options.
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
    /// Resource policy for the fit — usage not visible here; confirm against the consumer.
    pub policy: crate::resource::ResourcePolicy,
}
160
/// Request payload carried by `FitRequest::SurvivalMarginalSlope`.
///
/// NOTE(review): the consumer of this request is not part of the section
/// reviewed here; field notes are limited to what the types show.
pub struct SurvivalMarginalSlopeFitRequest<'a> {
    /// Borrowed data matrix.
    pub data: ArrayView2<'a, f64>,
    /// Survival marginal-slope term specification.
    pub spec: SurvivalMarginalSlopeTermSpec,
    /// Blockwise fit options.
    pub options: BlockwiseFitOptions,
    /// Spatial length-scale optimization options.
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
}
167
/// Request payload carried by `FitRequest::LatentSurvival`.
///
/// NOTE(review): the consumer of this request is not part of the section
/// reviewed here; field notes are limited to what the types show.
pub struct LatentSurvivalFitRequest<'a> {
    /// Borrowed data matrix.
    pub data: ArrayView2<'a, f64>,
    /// Latent-survival term specification.
    pub spec: LatentSurvivalTermSpec,
    /// Frailty specification.
    pub frailty: FrailtySpec,
    /// Blockwise fit options.
    pub options: BlockwiseFitOptions,
}
174
/// Request payload carried by `FitRequest::LatentBinary`.
///
/// NOTE(review): the consumer of this request is not part of the section
/// reviewed here; field notes are limited to what the types show.
pub struct LatentBinaryFitRequest<'a> {
    /// Borrowed data matrix.
    pub data: ArrayView2<'a, f64>,
    /// Latent-binary term specification.
    pub spec: LatentBinaryTermSpec,
    /// Frailty specification.
    pub frailty: FrailtySpec,
    /// Blockwise fit options.
    pub options: BlockwiseFitOptions,
}
181
/// Request payload carried by `FitRequest::TransformationNormal`.
///
/// NOTE(review): the consumer of this request is not part of the section
/// reviewed here; field notes are limited to what the types show.
pub struct TransformationNormalFitRequest<'a> {
    /// Borrowed data matrix.
    pub data: ArrayView2<'a, f64>,
    /// Response vector.
    pub response: Array1<f64>,
    /// Observation weights.
    pub weights: Array1<f64>,
    /// Offset vector.
    pub offset: Array1<f64>,
    /// Term specification for the covariates.
    pub covariate_spec: TermCollectionSpec,
    /// Transformation-normal configuration.
    pub config: TransformationNormalConfig,
    /// Blockwise fit options.
    pub options: BlockwiseFitOptions,
    /// Spatial length-scale optimization options.
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
    /// Optional warm start.
    pub warm_start: Option<TransformationWarmStart>,
}
193
/// A fit request for any of the supported workflows; each variant wraps the
/// request struct of the same name.
pub enum FitRequest<'a> {
    /// Standard term-collection fit, optionally followed by a link-wiggle refit.
    Standard(StandardFitRequest<'a>),
    /// Gaussian location-scale fit.
    GaussianLocationScale(GaussianLocationScaleFitRequest<'a>),
    /// Binomial location-scale fit.
    BinomialLocationScale(BinomialLocationScaleFitRequest<'a>),
    /// Survival location-scale fit.
    SurvivalLocationScale(SurvivalLocationScaleFitRequest<'a>),
    /// Survival transformation fit.
    SurvivalTransformation(SurvivalTransformationFitRequest<'a>),
    /// Bernoulli marginal-slope fit.
    BernoulliMarginalSlope(BernoulliMarginalSlopeFitRequest<'a>),
    /// Survival marginal-slope fit.
    SurvivalMarginalSlope(SurvivalMarginalSlopeFitRequest<'a>),
    /// Latent-survival fit.
    LatentSurvival(LatentSurvivalFitRequest<'a>),
    /// Latent-binary fit.
    LatentBinary(LatentBinaryFitRequest<'a>),
    /// Transformation-normal fit.
    TransformationNormal(TransformationNormalFitRequest<'a>),
}
206
/// Output of `fit_standard_model`.
pub struct StandardFitResult {
    /// The first-stage fit, or the wiggle refit when a wiggle was requested.
    pub fit: UnifiedFitResult,
    /// Design from the same stage as `fit`.
    pub design: TermCollectionDesign,
    /// Resolved term specification from the same stage as `fit`.
    pub resolvedspec: TermCollectionSpec,
    /// Adaptive-regularization diagnostics from the first-stage fit; carried
    /// over unchanged after a wiggle refit.
    pub adaptive_diagnostics: Option<AdaptiveRegularizationDiagnostics>,
    /// Copy of the first-stage fit's `fitted_link`, kept even when `fit` is
    /// replaced by the wiggle refit.
    pub saved_link_state: FittedLinkState,
    /// Knots reported by the wiggle refit; `None` when no wiggle was requested.
    pub wiggle_knots: Option<Array1<f64>>,
    /// Degree reported by the wiggle refit; `None` when no wiggle was requested.
    pub wiggle_degree: Option<usize>,
}
216
/// Public result of a survival location-scale fit; built field-for-field by
/// `SurvivalLocationScaleProfile::into_result`.
pub struct SurvivalLocationScaleFitResult {
    /// The fitted survival location-scale model.
    pub fit: SurvivalLocationScaleTermFitResult,
    /// Inverse link associated with `fit`.
    pub inverse_link: InverseLink,
    /// Wiggle knots, when present.
    pub wiggle_knots: Option<Array1<f64>>,
    /// Wiggle degree, when present.
    pub wiggle_degree: Option<usize>,
}
223
/// Return type of `fit_survival_transformation_model`.
///
/// NOTE(review): the code that fills these fields is not part of the section
/// reviewed here; the field notes are limited to names and types.
pub struct SurvivalTransformationFitResult {
    /// The fitted model.
    pub fit: UnifiedFitResult,
    /// Resolved covariate term specification.
    pub resolvedspec: TermCollectionSpec,
    /// Baseline configuration (same type as `SurvivalTransformationTermSpec::baseline_cfg`).
    pub baseline_cfg: crate::families::survival_construction::SurvivalBaselineConfig,
    /// Likelihood mode (same type as `SurvivalTransformationTermSpec::likelihood_mode`).
    pub likelihood_mode: crate::families::survival_construction::SurvivalLikelihoodMode,
    /// Time anchor.
    pub time_anchor: f64,
    /// Name of the time basis.
    pub time_basisname: String,
    /// Column count of the base time design — presumably excluding any time-wiggle
    /// columns; confirm against the producer.
    pub time_base_ncols: usize,
    /// Degree of the time basis, when applicable.
    pub time_degree: Option<usize>,
    /// Knots of the time basis, when applicable.
    pub time_knots: Option<Vec<f64>>,
    /// Retained time-basis column indices, when applicable.
    pub time_keep_cols: Option<Vec<usize>>,
    /// Smoothing lambda of the time basis, when set.
    pub time_smooth_lambda: Option<f64>,
    /// Time-wiggle block input for the baseline, when a time wiggle was used.
    pub baseline_timewiggle: Option<TimeWiggleBlockInput>,
}
238
/// Private counterpart of `SurvivalLocationScaleFitResult` with the same four
/// fields; see the `impl` for the objective accessor and the conversion.
struct SurvivalLocationScaleProfile {
    /// The fitted survival location-scale model.
    fit: SurvivalLocationScaleTermFitResult,
    /// Inverse link associated with `fit`.
    inverse_link: InverseLink,
    /// Wiggle knots, when present.
    wiggle_knots: Option<Array1<f64>>,
    /// Wiggle degree, when present.
    wiggle_degree: Option<usize>,
}
245
246impl SurvivalLocationScaleProfile {
247    fn objective(&self) -> f64 {
248        self.fit.fit.reml_score
249    }
250
251    fn into_result(self) -> SurvivalLocationScaleFitResult {
252        SurvivalLocationScaleFitResult {
253            fit: self.fit,
254            inverse_link: self.inverse_link,
255            wiggle_knots: self.wiggle_knots,
256            wiggle_degree: self.wiggle_degree,
257        }
258    }
259}
260
/// Result of a fit; the variant names mirror those of `FitRequest`.
pub enum FitResult {
    /// Result of the standard workflow.
    Standard(StandardFitResult),
    /// Result of the Gaussian location-scale workflow.
    GaussianLocationScale(GaussianLocationScaleFitResult),
    /// Result of the binomial location-scale workflow.
    BinomialLocationScale(BinomialLocationScaleFitResult),
    /// Result of the survival location-scale workflow.
    SurvivalLocationScale(SurvivalLocationScaleFitResult),
    /// Result of the survival transformation workflow.
    SurvivalTransformation(SurvivalTransformationFitResult),
    /// Result of the Bernoulli marginal-slope workflow.
    BernoulliMarginalSlope(BernoulliMarginalSlopeFitResult),
    /// Result of the survival marginal-slope workflow.
    SurvivalMarginalSlope(SurvivalMarginalSlopeFitResult),
    /// Result of the latent-survival workflow.
    LatentSurvival(LatentSurvivalTermFitResult),
    /// Result of the latent-binary workflow.
    LatentBinary(LatentBinaryTermFitResult),
    /// Result of the transformation-normal workflow.
    TransformationNormal(TransformationNormalFitResult),
}
273
274fn resolved_wiggle_inverse_link(
275    family: LikelihoodFamily,
276    fit: &UnifiedFitResult,
277    fallback: &InverseLink,
278) -> Result<InverseLink, String> {
279    let resolved = match fit.fitted_link_state(family).map_err(|e| e.to_string())? {
280        FittedLinkState::Standard(Some(link)) => InverseLink::Standard(link),
281        FittedLinkState::Standard(None) => fallback.clone(),
282        FittedLinkState::LatentCLogLog { state } => InverseLink::LatentCLogLog(state),
283        FittedLinkState::Sas { state, .. } => InverseLink::Sas(state),
284        FittedLinkState::BetaLogistic { state, .. } => InverseLink::BetaLogistic(state),
285        FittedLinkState::Mixture { state, .. } => InverseLink::Mixture(state),
286    };
287    require_inverse_link_supports_joint_wiggle(&resolved, "standard link wiggle")?;
288    Ok(resolved)
289}
290
291fn deviation_block_config_from_formula_linkwiggle(
292    wiggle: &LinkWiggleFormulaSpec,
293) -> DeviationBlockConfig {
294    let defaults = WigglePenaltyConfig::cubic_triple_operator_default();
295    DeviationBlockConfig {
296        degree: wiggle.degree,
297        num_internal_knots: wiggle.num_internal_knots,
298        penalty_order: *wiggle.penalty_orders.iter().max().unwrap_or(&2),
299        penalty_orders: wiggle.penalty_orders.clone(),
300        double_penalty: wiggle.double_penalty,
301        monotonicity_eps: defaults.monotonicity_eps,
302    }
303}
304
/// Deviation-block configs produced by `route_marginal_slope_deviation_blocks`.
struct MarginalSlopeDeviationRouting {
    /// Config derived from the log-slope link-wiggle spec, when one was given.
    score_warp: Option<DeviationBlockConfig>,
    /// Config derived from the main link-wiggle spec, when one was given.
    link_dev: Option<DeviationBlockConfig>,
}
309
310fn route_marginal_slope_deviation_blocks(
311    main_linkwiggle: Option<&LinkWiggleFormulaSpec>,
312    logslope_linkwiggle: Option<&LinkWiggleFormulaSpec>,
313) -> Result<MarginalSlopeDeviationRouting, String> {
314    Ok(MarginalSlopeDeviationRouting {
315        score_warp: logslope_linkwiggle.map(deviation_block_config_from_formula_linkwiggle),
316        link_dev: main_linkwiggle.map(deviation_block_config_from_formula_linkwiggle),
317    })
318}
319
320fn fixed_gaussian_shift_frailty_from_spec(
321    frailty: &FrailtySpec,
322    context: &str,
323) -> Result<FrailtySpec, String> {
324    match frailty {
325        FrailtySpec::None => Ok(FrailtySpec::None),
326        FrailtySpec::GaussianShift {
327            sigma_fixed: Some(sigma),
328        } => Ok(FrailtySpec::GaussianShift {
329            sigma_fixed: Some(*sigma),
330        }),
331        FrailtySpec::GaussianShift { sigma_fixed: None } => Err(format!(
332            "{context} currently requires a fixed GaussianShift sigma"
333        )),
334        FrailtySpec::HazardMultiplier { .. } => Err(format!(
335            "{context} requires FrailtySpec::GaussianShift or no frailty"
336        )),
337    }
338}
339
340fn fit_standard_model(request: StandardFitRequest<'_>) -> Result<StandardFitResult, String> {
341    let fitted = fit_term_collectionwith_spatial_length_scale_optimization(
342        request.data,
343        request.y.clone(),
344        request.weights.clone(),
345        request.offset.clone(),
346        &request.spec,
347        request.family,
348        &request.options,
349        &request.kappa_options,
350    )
351    .map_err(|e| e.to_string())?;
352
353    let result = StandardFitResult {
354        saved_link_state: fitted.fit.fitted_link.clone(),
355        fit: fitted.fit,
356        design: fitted.design,
357        resolvedspec: fitted.resolvedspec,
358        adaptive_diagnostics: fitted.adaptive_diagnostics,
359        wiggle_knots: None,
360        wiggle_degree: None,
361    };
362
363    let Some(wiggle) = request.wiggle else {
364        return Ok(result);
365    };
366    let wiggle_options = request
367        .wiggle_options
368        .ok_or_else(|| "standard wiggle workflow requires blockwise wiggle options".to_string())?;
369    let wiggle_link_kind =
370        resolved_wiggle_inverse_link(request.family, &result.fit, &wiggle.link_kind)?;
371    let selected_wiggle_basis = select_binomial_mean_link_wiggle_basis_from_pilot(
372        &result.design,
373        &result.fit,
374        &WiggleBlockConfig {
375            degree: wiggle.wiggle.degree,
376            num_internal_knots: wiggle.wiggle.num_internal_knots,
377            penalty_order: 2,
378            double_penalty: wiggle.wiggle.double_penalty,
379        },
380        &wiggle.wiggle.penalty_orders,
381    )?;
382
383    let solved = fit_binomial_mean_wiggle_terms_with_selected_basis(
384        request.data,
385        &result.resolvedspec,
386        &result.design,
387        &result.fit,
388        &request.y,
389        &request.weights,
390        wiggle_link_kind,
391        selected_wiggle_basis,
392        &wiggle_options,
393        &request.kappa_options,
394    )?;
395
396    Ok(StandardFitResult {
397        saved_link_state: result.saved_link_state,
398        fit: solved.fit,
399        design: solved.design,
400        resolvedspec: solved.resolvedspec,
401        adaptive_diagnostics: result.adaptive_diagnostics,
402        wiggle_knots: Some(solved.wiggle_knots),
403        wiggle_degree: Some(solved.wiggle_degree),
404    })
405}
406
407fn fit_gaussian_location_scale_model(
408    request: GaussianLocationScaleFitRequest<'_>,
409) -> Result<GaussianLocationScaleFitResult, String> {
410    if let Some(wiggle_cfg) = request.wiggle {
411        let pilot = fit_gaussian_location_scale_terms(
412            request.data,
413            GaussianLocationScaleTermSpec {
414                y: request.spec.y.clone(),
415                weights: request.spec.weights.clone(),
416                meanspec: request.spec.meanspec.clone(),
417                log_sigmaspec: request.spec.log_sigmaspec.clone(),
418                mean_offset: request.spec.mean_offset.clone(),
419                log_sigma_offset: request.spec.log_sigma_offset.clone(),
420            },
421            &request.options,
422            &request.kappa_options,
423        )?;
424        let selected_wiggle_basis = select_gaussian_location_scale_link_wiggle_basis_from_pilot(
425            &pilot,
426            &WiggleBlockConfig {
427                degree: wiggle_cfg.degree,
428                num_internal_knots: wiggle_cfg.num_internal_knots,
429                penalty_order: 2,
430                double_penalty: wiggle_cfg.double_penalty,
431            },
432            &wiggle_cfg.penalty_orders,
433        )?;
434        let solved = fit_gaussian_location_scale_terms_with_selected_wiggle(
435            request.data,
436            request.spec,
437            selected_wiggle_basis,
438            &request.options,
439            &request.kappa_options,
440        )?;
441        let fit = solved.fit.fit;
442        let beta_link_wiggle = fit.block_states.get(2).map(|b| b.beta.to_vec());
443        Ok(GaussianLocationScaleFitResult {
444            fit: BlockwiseTermFitResult::try_from_parts(BlockwiseTermFitResultParts {
445                fit,
446                meanspec_resolved: solved.fit.meanspec_resolved,
447                noisespec_resolved: solved.fit.noisespec_resolved,
448                mean_design: solved.fit.mean_design,
449                noise_design: solved.fit.noise_design,
450            })?,
451            wiggle_knots: Some(solved.wiggle_knots),
452            wiggle_degree: Some(solved.wiggle_degree),
453            beta_link_wiggle,
454        })
455    } else {
456        let fit = fit_gaussian_location_scale_terms(
457            request.data,
458            request.spec,
459            &request.options,
460            &request.kappa_options,
461        )?;
462        Ok(GaussianLocationScaleFitResult {
463            fit,
464            wiggle_knots: None,
465            wiggle_degree: None,
466            beta_link_wiggle: None,
467        })
468    }
469}
470
471fn fit_binomial_location_scale_model(
472    request: BinomialLocationScaleFitRequest<'_>,
473) -> Result<BinomialLocationScaleFitResult, String> {
474    if let Some(wiggle_cfg) = request.wiggle {
475        require_inverse_link_supports_joint_wiggle(
476            &request.spec.link_kind,
477            "binomial location-scale link wiggle",
478        )?;
479        let pilot = fit_binomial_location_scale_terms(
480            request.data,
481            BinomialLocationScaleTermSpec {
482                y: request.spec.y.clone(),
483                weights: request.spec.weights.clone(),
484                link_kind: request.spec.link_kind.clone(),
485                thresholdspec: request.spec.thresholdspec.clone(),
486                log_sigmaspec: request.spec.log_sigmaspec.clone(),
487                threshold_offset: request.spec.threshold_offset.clone(),
488                log_sigma_offset: request.spec.log_sigma_offset.clone(),
489            },
490            &request.options,
491            &request.kappa_options,
492        )?;
493        let selected_wiggle_basis = select_binomial_location_scale_link_wiggle_basis_from_pilot(
494            &pilot,
495            &WiggleBlockConfig {
496                degree: wiggle_cfg.degree,
497                num_internal_knots: wiggle_cfg.num_internal_knots,
498                penalty_order: 2,
499                double_penalty: wiggle_cfg.double_penalty,
500            },
501            &wiggle_cfg.penalty_orders,
502        )?;
503        let solved = fit_binomial_location_scale_terms_with_selected_wiggle(
504            request.data,
505            request.spec,
506            selected_wiggle_basis,
507            &request.options,
508            &request.kappa_options,
509        )?;
510        let fit = solved.fit.fit;
511        let beta_link_wiggle = fit.block_states.get(2).map(|b| b.beta.to_vec());
512        Ok(BinomialLocationScaleFitResult {
513            fit: BlockwiseTermFitResult::try_from_parts(BlockwiseTermFitResultParts {
514                fit,
515                meanspec_resolved: solved.fit.meanspec_resolved,
516                noisespec_resolved: solved.fit.noisespec_resolved,
517                mean_design: solved.fit.mean_design,
518                noise_design: solved.fit.noise_design,
519            })?,
520            wiggle_knots: Some(solved.wiggle_knots),
521            wiggle_degree: Some(solved.wiggle_degree),
522            beta_link_wiggle,
523        })
524    } else {
525        let solved = fit_binomial_location_scale_terms(
526            request.data,
527            request.spec,
528            &request.options,
529            &request.kappa_options,
530        )?;
531        Ok(BinomialLocationScaleFitResult {
532            fit: solved,
533            wiggle_knots: None,
534            wiggle_degree: None,
535            beta_link_wiggle: None,
536        })
537    }
538}
539
540fn survival_working_reml_score(state: &crate::pirls::WorkingState) -> f64 {
541    0.5 * (state.deviance + state.penalty_term)
542}
543
544fn fitted_weibull_baseline_from_linear_time_beta(
545    beta: &Array1<f64>,
546) -> Option<crate::families::survival_construction::SurvivalBaselineConfig> {
547    if beta.len() < 2 {
548        return None;
549    }
550    let shape = beta[1];
551    if !shape.is_finite() || shape <= 0.0 {
552        return None;
553    }
554    let scale = (-beta[0] / shape).exp();
555    if !scale.is_finite() || scale <= 0.0 {
556        return None;
557    }
558    Some(
559        crate::families::survival_construction::SurvivalBaselineConfig {
560            target: SurvivalBaselineTarget::Weibull,
561            scale: Some(scale),
562            shape: Some(shape),
563            rate: None,
564            makeham: None,
565        },
566    )
567}
568
569fn survival_unified_fit_result(
570    beta: Array1<f64>,
571    lambdas: Array1<f64>,
572    summary: &crate::pirls::WorkingModelPirlsResult,
573    state: &crate::pirls::WorkingState,
574) -> Result<UnifiedFitResult, String> {
575    let log_lambdas = lambdas.mapv(|v| v.max(1e-300).ln());
576    let reml_score = survival_working_reml_score(state);
577    crate::estimate::validate_all_finite("survival fit beta", beta.iter().copied())?;
578    crate::estimate::validate_all_finite("survival fit lambdas", lambdas.iter().copied())?;
579    crate::estimate::ensure_finite_scalar("survival fit log_likelihood", state.log_likelihood)?;
580    crate::estimate::ensure_finite_scalar("survival fit deviance", state.deviance)?;
581    crate::estimate::ensure_finite_scalar("survival fit penalty", state.penalty_term)?;
582    crate::estimate::ensure_finite_scalar("survival fit reml_score", reml_score)?;
583    crate::estimate::ensure_finite_scalar("survival fit gradient_norm", summary.lastgradient_norm)?;
584    crate::estimate::ensure_finite_scalar("survival fit max_abs_eta", summary.max_abs_eta)?;
585
586    UnifiedFitResult::try_from_parts(crate::estimate::UnifiedFitResultParts {
587        blocks: vec![crate::estimate::FittedBlock {
588            beta: beta.clone(),
589            role: crate::estimate::BlockRole::Mean,
590            edf: 0.0,
591            lambdas: lambdas.clone(),
592        }],
593        log_lambdas,
594        lambdas,
595        likelihood_family: Some(LikelihoodFamily::RoystonParmar),
596        likelihood_scale: crate::types::LikelihoodScaleMetadata::Unspecified,
597        log_likelihood_normalization: crate::types::LogLikelihoodNormalization::UserProvided,
598        log_likelihood: state.log_likelihood,
599        deviance: state.deviance,
600        reml_score,
601        stable_penalty_term: state.penalty_term,
602        penalized_objective: reml_score,
603        outer_iterations: summary.iterations,
604        outer_converged: true,
605        outer_gradient_norm: summary.lastgradient_norm,
606        standard_deviation: 1.0,
607        covariance_conditional: None,
608        covariance_corrected: None,
609        inference: None,
610        fitted_link: FittedLinkState::Standard(None),
611        geometry: None,
612        block_states: Vec::new(),
613        pirls_status: summary.status,
614        max_abs_eta: summary.max_abs_eta,
615        constraint_kkt: None,
616        artifacts: crate::estimate::FitArtifacts {
617            pirls: None,
618            ..Default::default()
619        },
620        inner_cycles: 0,
621    })
622    .map_err(|err| err.to_string())
623}
624
625fn fit_survival_transformation_model(
626    request: SurvivalTransformationFitRequest<'_>,
627) -> Result<SurvivalTransformationFitResult, String> {
628    use crate::survival::{MonotonicityPenalty, PenaltyBlock, PenaltyBlocks, SurvivalSpec};
629
630    let SurvivalTransformationFitRequest { data, spec } = request;
631    let mut baseline_cfg = spec.baseline_cfg.clone();
632    let covariate_design =
633        build_term_collection_design(data, &spec.covariate_spec).map_err(|err| err.to_string())?;
634    let resolvedspec =
635        crate::smooth::freeze_term_collection_from_design(&spec.covariate_spec, &covariate_design)
636            .map_err(|err| err.to_string())?;
637    let dense_cov_design = covariate_design.design.to_dense();
638    let p_cov = dense_cov_design.ncols();
639    let event_competing = Array1::<u8>::zeros(spec.event_target.len());
640    let exact_derivative_guard = survival_derivative_guard_for_likelihood(spec.likelihood_mode);
641
642    let build_working_model =
643        |candidate: &crate::families::survival_construction::SurvivalBaselineConfig| {
644            let prepared = prepare_workflow_survival_time_stack(
645                &spec.age_entry,
646                &spec.age_exit,
647                candidate,
648                spec.likelihood_mode,
649                None,
650                spec.time_anchor,
651                exact_derivative_guard,
652                &spec.time_build,
653                spec.timewiggle.as_ref(),
654                None,
655            )?;
656            let mut eta_offset_entry = prepared.eta_offset_entry.clone();
657            let mut eta_offset_exit = prepared.eta_offset_exit.clone();
658            eta_offset_entry += &spec.covariate_offset;
659            eta_offset_exit += &spec.covariate_offset;
660            let p_time_total = prepared.time_design_exit.ncols();
661            let p = p_time_total + p_cov;
662            let mut penalty_blocks = Vec::<PenaltyBlock>::new();
663            for (idx, penalty) in prepared.time_penalties.iter().enumerate() {
664                if penalty.nrows() == p_time_total && penalty.ncols() == p_time_total {
665                    penalty_blocks.push(PenaltyBlock {
666                        matrix: penalty.clone(),
667                        lambda: spec.time_build.smooth_lambda.unwrap_or(1e-2),
668                        range: 0..p_time_total,
669                        nullspace_dim: prepared.time_nullspace_dims.get(idx).copied().unwrap_or(0),
670                    });
671                }
672            }
673            let ridge_range_start = if spec.likelihood_mode == SurvivalLikelihoodMode::Weibull
674                && spec.time_build.basisname == "linear"
675                && spec.timewiggle.is_none()
676            {
677                1
678            } else {
679                0
680            };
681            if spec.ridge_lambda > 0.0 && p > ridge_range_start {
682                let dim = p - ridge_range_start;
683                let mut ridge = Array2::<f64>::zeros((dim, dim));
684                for d in 0..dim {
685                    ridge[[d, d]] = 1.0;
686                }
687                penalty_blocks.push(PenaltyBlock {
688                    matrix: ridge,
689                    lambda: spec.ridge_lambda,
690                    range: ridge_range_start..p,
691                    nullspace_dim: 0,
692                });
693            }
694            let dense_time_entry = prepared.time_design_entry.to_dense();
695            let dense_time_exit = prepared.time_design_exit.to_dense();
696            let dense_time_derivative = prepared.time_design_derivative.to_dense();
697            let mut model =
698                crate::families::royston_parmar::working_model_from_time_covariateshared(
699                    PenaltyBlocks::new(penalty_blocks.clone()),
700                    MonotonicityPenalty { tolerance: 0.0 },
701                    SurvivalSpec::Net,
702                    crate::families::royston_parmar::RoystonParmarSharedTimeCovariateInputs {
703                        age_entry: spec.age_entry.view(),
704                        age_exit: spec.age_exit.view(),
705                        event_target: spec.event_target.view(),
706                        event_competing: event_competing.view(),
707                        weights: spec.weights.view(),
708                        time_entry: dense_time_entry.view(),
709                        time_exit: dense_time_exit.view(),
710                        time_derivative: dense_time_derivative.view(),
711                        covariates: dense_cov_design.view(),
712                        monotonicity_constraint_rows: None,
713                        monotonicity_constraint_offsets: None,
714                        eta_offset_entry: Some(eta_offset_entry.view()),
715                        eta_offset_exit: Some(eta_offset_exit.view()),
716                        derivative_offset_exit: Some(prepared.derivative_offset_exit.view()),
717                    },
718                )
719                .map_err(|err| format!("failed to construct survival model: {err}"))?;
720            if spec.likelihood_mode != SurvivalLikelihoodMode::Weibull {
721                model
722                    .set_structural_monotonicity(true, p_time_total)
723                    .map_err(|err| format!("failed to enable structural monotonicity: {err}"))?;
724            }
725            let mut beta0 = Array1::<f64>::zeros(p);
726            if spec.likelihood_mode == SurvivalLikelihoodMode::Weibull && spec.timewiggle.is_none()
727            {
728                let (scale, shape) = spec
729                    .weibull_seed
730                    .ok_or_else(|| "weibull survival fit missing scale/shape seed".to_string())?;
731                if p_time_total < 2 {
732                    return Err(format!(
733                        "weibull built-in time basis has {p_time_total} columns but needs 2 to seed scale/shape"
734                    ));
735                }
736                beta0[0] = -shape * scale.ln();
737                beta0[1] = shape;
738            }
739            let structural_lower_bounds =
740                if spec.likelihood_mode != SurvivalLikelihoodMode::Weibull && p_time_total > 0 {
741                    let mut lb = Array1::from_elem(p, f64::NEG_INFINITY);
742                    for j in 0..p_time_total {
743                        lb[j] = 0.0;
744                        beta0[j] = 1e-4;
745                    }
746                    Some(lb)
747                } else {
748                    None
749                };
750            Ok::<_, String>((
751                prepared,
752                penalty_blocks,
753                beta0,
754                structural_lower_bounds,
755                model,
756            ))
757        };
758
759    if baseline_cfg.target != SurvivalBaselineTarget::Linear {
760        baseline_cfg = optimize_survival_baseline_config(
761            &baseline_cfg,
762            "workflow survival transformation baseline",
763            |candidate| {
764                let (_, _, beta0, structural_lower_bounds, mut model) =
765                    build_working_model(candidate)?;
766                let opts = crate::pirls::WorkingModelPirlsOptions {
767                    max_iterations: 400,
768                    convergence_tolerance: 1e-6,
769                    max_step_halving: 40,
770                    min_step_size: 1e-12,
771                    firth_bias_reduction: false,
772                    coefficient_lower_bounds: structural_lower_bounds,
773                    linear_constraints: None,
774                    initial_lm_lambda: None,
775                };
776                let summary = crate::pirls::runworking_model_pirls(
777                    &mut model,
778                    crate::types::Coefficients::new(beta0),
779                    &opts,
780                    |_| {},
781                )
782                .map_err(|err| format!("survival PIRLS failed: {err}"))?;
783                let beta = summary.beta.as_ref().to_owned();
784                let state = model.update_state(&beta).map_err(|err| {
785                    format!("failed to evaluate survival baseline candidate: {err}")
786                })?;
787                Ok(survival_working_reml_score(&state))
788            },
789        )?;
790    }
791
792    let (prepared, penalty_blocks, beta0, structural_lower_bounds, mut model) =
793        build_working_model(&baseline_cfg)?;
794    let opts = crate::pirls::WorkingModelPirlsOptions {
795        max_iterations: 400,
796        convergence_tolerance: 1e-6,
797        max_step_halving: 40,
798        min_step_size: 1e-12,
799        firth_bias_reduction: false,
800        coefficient_lower_bounds: structural_lower_bounds,
801        linear_constraints: None,
802        initial_lm_lambda: None,
803    };
804    let summary = crate::pirls::runworking_model_pirls(
805        &mut model,
806        crate::types::Coefficients::new(beta0),
807        &opts,
808        |_| {},
809    )
810    .map_err(|err| format!("survival PIRLS failed: {err}"))?;
811    match summary.status {
812        crate::pirls::PirlsStatus::Converged | crate::pirls::PirlsStatus::StalledAtValidMinimum => {
813        }
814        ref other => {
815            return Err(format!(
816                "survival PIRLS did not converge: status={other:?}, grad_norm={:.3e}, iterations={}, deviance={:.6e}",
817                summary.lastgradient_norm, summary.iterations, summary.state.deviance
818            ));
819        }
820    }
821    let beta = summary.beta.as_ref().to_owned();
822    let state = model
823        .update_state(&beta)
824        .map_err(|err| format!("failed to evaluate survival optimum: {err}"))?;
825    let lambdas = Array1::from_iter(penalty_blocks.iter().map(|block| block.lambda));
826    let fitted_baseline_cfg =
827        if spec.likelihood_mode == SurvivalLikelihoodMode::Weibull && spec.timewiggle.is_none() {
828            let time_beta = beta
829                .slice(s![..spec.time_build.x_exit_time.ncols()])
830                .to_owned();
831            fitted_weibull_baseline_from_linear_time_beta(&time_beta).ok_or_else(|| {
832                "failed to recover fitted Weibull scale/shape from the linear time coefficients"
833                    .to_string()
834            })?
835        } else {
836            baseline_cfg
837        };
838    let fit = survival_unified_fit_result(beta, lambdas, &summary, &state)?;
839
840    Ok(SurvivalTransformationFitResult {
841        fit,
842        resolvedspec,
843        baseline_cfg: fitted_baseline_cfg,
844        likelihood_mode: spec.likelihood_mode,
845        time_anchor: spec.time_anchor,
846        time_basisname: spec.time_build.basisname.clone(),
847        time_base_ncols: spec.time_build.x_exit_time.ncols(),
848        time_degree: spec.time_build.degree,
849        time_knots: spec.time_build.knots.clone(),
850        time_keep_cols: spec.time_build.keep_cols.clone(),
851        time_smooth_lambda: spec.time_build.smooth_lambda,
852        baseline_timewiggle: prepared.timewiggle_block,
853    })
854}
855
856fn fit_survival_location_scale_model(
857    request: SurvivalLocationScaleFitRequest<'_>,
858) -> Result<SurvivalLocationScaleFitResult, String> {
859    // Profile one coherent survival subproblem at a fixed inverse-link state:
860    // select/apply the link-wiggle basis for that state, then solve the full
861    // penalized location-scale fit on the resulting model.
862    fn profile_survival_location_scale(
863        data: ArrayView2<'_, f64>,
864        spec: SurvivalLocationScaleTermSpec,
865        wiggle: Option<LinkWiggleConfig>,
866        kappa_options: &SpatialLengthScaleOptimizationOptions,
867    ) -> Result<SurvivalLocationScaleProfile, String> {
868        let mut wiggle_knots = None;
869        let mut wiggle_degree = None;
870        let inverse_link = spec.inverse_link.clone();
871
872        let fit = if let Some(wiggle) = wiggle {
873            require_inverse_link_supports_joint_wiggle(&inverse_link, "survival link wiggle")?;
874            let mut pilot_spec = spec.clone();
875            pilot_spec.linkwiggle_block = None;
876            let pilot = fit_survival_location_scale_terms(data, pilot_spec, kappa_options)?;
877            let selected_wiggle_basis = select_survival_link_wiggle_basis_from_pilot(
878                &pilot,
879                &WiggleBlockConfig {
880                    degree: wiggle.degree,
881                    num_internal_knots: wiggle.num_internal_knots,
882                    penalty_order: 2,
883                    double_penalty: wiggle.double_penalty,
884                },
885                &wiggle.penalty_orders,
886            )?;
887            wiggle_knots = Some(selected_wiggle_basis.knots.clone());
888            wiggle_degree = Some(selected_wiggle_basis.degree);
889            fit_survival_location_scale_terms_with_selected_wiggle(
890                data,
891                spec,
892                selected_wiggle_basis,
893                kappa_options,
894            )?
895        } else {
896            fit_survival_location_scale_terms(data, spec, kappa_options)?
897        };
898
899        Ok(SurvivalLocationScaleProfile {
900            fit,
901            inverse_link,
902            wiggle_knots,
903            wiggle_degree,
904        })
905    }
906
907    fn profile_survival_location_scale_with_inverse_link(
908        data: ArrayView2<'_, f64>,
909        spec: &SurvivalLocationScaleTermSpec,
910        inverse_link: InverseLink,
911        wiggle: Option<LinkWiggleConfig>,
912        kappa_options: &SpatialLengthScaleOptimizationOptions,
913    ) -> Result<SurvivalLocationScaleProfile, String> {
914        let mut spec_at_link = spec.clone();
915        spec_at_link.inverse_link = inverse_link;
916        profile_survival_location_scale(data, spec_at_link, wiggle, kappa_options)
917    }
918
919    fn optimize_survival_inverse_link_profile(
920        data: ArrayView2<'_, f64>,
921        spec: &SurvivalLocationScaleTermSpec,
922        wiggle: Option<LinkWiggleConfig>,
923        kappa_options: &SpatialLengthScaleOptimizationOptions,
924    ) -> Result<SurvivalLocationScaleProfile, String> {
925        fn optimize_link_parameters<F, R>(
926            data: ArrayView2<'_, f64>,
927            spec: &SurvivalLocationScaleTermSpec,
928            kappa_options: &SpatialLengthScaleOptimizationOptions,
929            init: Array1<f64>,
930            name: &str,
931            final_wiggle: Option<LinkWiggleConfig>,
932            objective: F,
933            recover: R,
934        ) -> Result<SurvivalLocationScaleProfile, String>
935        where
936            F: FnMut(&Array1<f64>) -> Result<f64, EstimationError>,
937            R: Fn(&Array1<f64>) -> Option<InverseLink>,
938        {
939            use crate::solver::outer_strategy::{OuterProblem, SolverClass};
940            let dim = init.len();
941            // Inverse-link parameters (SAS epsilon/log_delta, BetaLogistic shape,
942            // Mixture rho) have no analytic ∂LAML/∂θ_link; route through the
943            // gated gradient-free CompassSearch variant rather than BFGS. Box
944            // bounds keep line-search probes inside a physically admissible
945            // region (|epsilon|, |log_delta| ≤ 6 gives the SAS link a finite
946            // range on both tails).
947            let lower = init.mapv(|v| v - 6.0);
948            let upper = init.mapv(|v| v + 6.0);
949            let problem = OuterProblem::new(dim)
950                .with_solver_class(SolverClass::AuxiliaryGradientFree)
951                .with_tolerance(1e-4)
952                .with_max_iter(240)
953                .with_bounds(lower, upper)
954                .with_heuristic_lambdas(init.to_vec());
955            let context = format!("survival inverse-link optimization ({name}, dim={dim})");
956            let mut obj = problem.build_objective(
957                objective,
958                |f: &mut F, rho: &ndarray::Array1<f64>| f(rho),
959                |_: &mut F, _: &ndarray::Array1<f64>| {
960                    Err(EstimationError::InvalidInput(
961                        "inverse-link aux optimizer: CompassSearch dispatch only \
962                         calls eval_cost; eval(gradient) is unreachable by \
963                         construction"
964                            .to_string(),
965                    ))
966                },
967                None::<fn(&mut F)>,
968                None::<
969                    fn(
970                        &mut F,
971                        &ndarray::Array1<f64>,
972                    )
973                        -> Result<crate::solver::outer_strategy::EfsEval, EstimationError>,
974                >,
975            );
976            let result = problem
977                .run(&mut obj, &context)
978                .map_err(|err| format!("{context} failed: {err}"))?;
979            let link = recover_converged_survival_inverse_link(result, &context, recover)?;
980            profile_survival_location_scale_with_inverse_link(
981                data,
982                spec,
983                link,
984                final_wiggle,
985                kappa_options,
986            )
987            .map_err(|err| format!("{context} final profiling failed: {err}"))
988        }
989
990        match spec.inverse_link.clone() {
991            InverseLink::Sas(state0) => {
992                let init = Array1::from_vec(vec![state0.epsilon, state0.log_delta]);
993                let wiggle_cfg = wiggle.clone();
994                optimize_link_parameters(
995                    data,
996                    spec,
997                    kappa_options,
998                    init,
999                    "SAS",
1000                    wiggle.clone(),
1001                    |theta: &Array1<f64>| {
1002                        let state = state_from_sasspec(SasLinkSpec {
1003                            initial_epsilon: theta[0],
1004                            initial_log_delta: theta[1],
1005                        })
1006                        .map_err(EstimationError::InvalidInput)?;
1007                        Ok(profile_survival_location_scale_with_inverse_link(
1008                            data,
1009                            spec,
1010                            InverseLink::Sas(state),
1011                            wiggle_cfg.clone(),
1012                            kappa_options,
1013                        )
1014                        .map_err(EstimationError::InvalidInput)?
1015                        .objective())
1016                    },
1017                    |rho| {
1018                        state_from_sasspec(SasLinkSpec {
1019                            initial_epsilon: rho[0],
1020                            initial_log_delta: rho[1],
1021                        })
1022                        .ok()
1023                        .map(InverseLink::Sas)
1024                    },
1025                )
1026            }
1027            InverseLink::BetaLogistic(state0) => {
1028                let init = Array1::from_vec(vec![state0.epsilon, state0.log_delta]);
1029                let wiggle_cfg = wiggle.clone();
1030                optimize_link_parameters(
1031                    data,
1032                    spec,
1033                    kappa_options,
1034                    init,
1035                    "BetaLogistic",
1036                    wiggle.clone(),
1037                    |theta: &Array1<f64>| {
1038                        let state = state_from_beta_logisticspec(SasLinkSpec {
1039                            initial_epsilon: theta[0],
1040                            initial_log_delta: theta[1],
1041                        })
1042                        .map_err(EstimationError::InvalidInput)?;
1043                        Ok(profile_survival_location_scale_with_inverse_link(
1044                            data,
1045                            spec,
1046                            InverseLink::BetaLogistic(state),
1047                            wiggle_cfg.clone(),
1048                            kappa_options,
1049                        )
1050                        .map_err(EstimationError::InvalidInput)?
1051                        .objective())
1052                    },
1053                    |rho| {
1054                        state_from_beta_logisticspec(SasLinkSpec {
1055                            initial_epsilon: rho[0],
1056                            initial_log_delta: rho[1],
1057                        })
1058                        .ok()
1059                        .map(InverseLink::BetaLogistic)
1060                    },
1061                )
1062            }
1063            InverseLink::Mixture(state0) if !state0.rho.is_empty() => {
1064                let components = state0.components.clone();
1065                let components_recover = components.clone();
1066                let wiggle_cfg = wiggle.clone();
1067                optimize_link_parameters(
1068                    data,
1069                    spec,
1070                    kappa_options,
1071                    state0.rho.clone(),
1072                    "mixture",
1073                    wiggle.clone(),
1074                    move |rho: &Array1<f64>| {
1075                        let state = state_fromspec(&MixtureLinkSpec {
1076                            components: components.clone(),
1077                            initial_rho: rho.clone(),
1078                        })
1079                        .map_err(EstimationError::InvalidInput)?;
1080                        Ok(profile_survival_location_scale_with_inverse_link(
1081                            data,
1082                            spec,
1083                            InverseLink::Mixture(state),
1084                            wiggle_cfg.clone(),
1085                            kappa_options,
1086                        )
1087                        .map_err(EstimationError::InvalidInput)?
1088                        .objective())
1089                    },
1090                    move |rho| {
1091                        state_fromspec(&MixtureLinkSpec {
1092                            components: components_recover.clone(),
1093                            initial_rho: rho.to_owned(),
1094                        })
1095                        .ok()
1096                        .map(InverseLink::Mixture)
1097                    },
1098                )
1099            }
1100            _ => profile_survival_location_scale(data, spec.clone(), wiggle, kappa_options),
1101        }
1102    }
1103
1104    let profile = if request.optimize_inverse_link {
1105        optimize_survival_inverse_link_profile(
1106            request.data,
1107            &request.spec,
1108            request.wiggle.clone(),
1109            &request.kappa_options,
1110        )?
1111    } else {
1112        profile_survival_location_scale(
1113            request.data,
1114            request.spec.clone(),
1115            request.wiggle.clone(),
1116            &request.kappa_options,
1117        )?
1118    };
1119
1120    Ok(profile.into_result())
1121}
1122
1123fn fit_bernoulli_marginal_slope_model(
1124    request: BernoulliMarginalSlopeFitRequest<'_>,
1125) -> Result<BernoulliMarginalSlopeFitResult, String> {
1126    // Phase 4: at biobank scale, auto-install a stratified outer-score
1127    // subsample so outer-only score / Hessian sweeps run on ~K rows instead
1128    // of ~n. Inner-PIRLS and final-covariance passes still use full data
1129    // because they don't consult `outer_score_subsample`. The helper is a
1130    // no-op below the biobank threshold or when a subsample is already set.
1131    let mut options = request.options.clone();
1132    crate::families::marginal_slope_shared::inject_biobank_outer_subsample_from_arrays(
1133        &mut options,
1134        request.spec.z.as_slice().expect("z is contiguous"),
1135        request.spec.y.as_slice().expect("y is contiguous"),
1136    );
1137    fit_bernoulli_marginal_slope_terms(
1138        request.data,
1139        request.spec,
1140        &options,
1141        &request.kappa_options,
1142        &request.policy,
1143    )
1144}
1145
1146fn fit_survival_marginal_slope_model(
1147    request: SurvivalMarginalSlopeFitRequest<'_>,
1148) -> Result<SurvivalMarginalSlopeFitResult, String> {
1149    // Phase 4: see `fit_bernoulli_marginal_slope_model` above. Survival
1150    // stratifies on the event indicator (`event_target`), which is the
1151    // canonical {0,1} secondary class for right-censored survival.
1152    let mut options = request.options.clone();
1153    crate::families::marginal_slope_shared::inject_biobank_outer_subsample_from_arrays(
1154        &mut options,
1155        request.spec.z.as_slice().expect("z is contiguous"),
1156        request
1157            .spec
1158            .event_target
1159            .as_slice()
1160            .expect("event_target is contiguous"),
1161    );
1162    fit_survival_marginal_slope_terms(request.data, request.spec, &options, &request.kappa_options)
1163}
1164
1165fn fit_latent_survival_model(
1166    request: LatentSurvivalFitRequest<'_>,
1167) -> Result<LatentSurvivalTermFitResult, String> {
1168    fit_latent_survival_terms(
1169        request.data,
1170        request.spec,
1171        request.frailty,
1172        &request.options,
1173    )
1174}
1175
1176fn fit_latent_binary_model(
1177    request: LatentBinaryFitRequest<'_>,
1178) -> Result<LatentBinaryTermFitResult, String> {
1179    fit_latent_binary_terms(
1180        request.data,
1181        request.spec,
1182        request.frailty,
1183        &request.options,
1184    )
1185}
1186
1187fn fit_transformation_normal_model(
1188    request: TransformationNormalFitRequest<'_>,
1189) -> Result<TransformationNormalFitResult, String> {
1190    fit_transformation_normal(
1191        &request.response,
1192        &request.weights,
1193        &request.offset,
1194        request.data,
1195        &request.covariate_spec,
1196        &request.config,
1197        &request.options,
1198        &request.kappa_options,
1199        request.warm_start.as_ref(),
1200    )
1201}
1202
1203pub fn fit_model(request: FitRequest<'_>) -> Result<FitResult, String> {
1204    match request {
1205        FitRequest::Standard(request) => fit_standard_model(request).map(FitResult::Standard),
1206        FitRequest::GaussianLocationScale(request) => {
1207            fit_gaussian_location_scale_model(request).map(FitResult::GaussianLocationScale)
1208        }
1209        FitRequest::BinomialLocationScale(request) => {
1210            fit_binomial_location_scale_model(request).map(FitResult::BinomialLocationScale)
1211        }
1212        FitRequest::SurvivalLocationScale(request) => {
1213            fit_survival_location_scale_model(request).map(FitResult::SurvivalLocationScale)
1214        }
1215        FitRequest::SurvivalTransformation(request) => {
1216            fit_survival_transformation_model(request).map(FitResult::SurvivalTransformation)
1217        }
1218        FitRequest::BernoulliMarginalSlope(request) => {
1219            fit_bernoulli_marginal_slope_model(request).map(FitResult::BernoulliMarginalSlope)
1220        }
1221        FitRequest::SurvivalMarginalSlope(request) => {
1222            fit_survival_marginal_slope_model(request).map(FitResult::SurvivalMarginalSlope)
1223        }
1224        FitRequest::LatentSurvival(request) => {
1225            fit_latent_survival_model(request).map(FitResult::LatentSurvival)
1226        }
1227        FitRequest::LatentBinary(request) => {
1228            fit_latent_binary_model(request).map(FitResult::LatentBinary)
1229        }
1230        FitRequest::TransformationNormal(request) => {
1231            fit_transformation_normal_model(request).map(FitResult::TransformationNormal)
1232        }
1233    }
1234}
1235
1236// ---------------------------------------------------------------------------
1237// High-level formula-to-fit API
1238// ---------------------------------------------------------------------------
1239
1240use crate::families::family_meta::{family_to_string, is_binomial_family};
1241use crate::families::survival_construction::{
1242    SurvivalBaselineTarget, SurvivalLikelihoodMode, SurvivalTimeBasisConfig,
1243    add_survival_time_derivative_guard_offset, append_zero_tail_columns,
1244    build_latent_survival_baseline_offsets, build_survival_time_basis,
1245    build_survival_time_offsets_for_likelihood, build_survival_timewiggle_from_baseline,
1246    build_time_varying_survival_covariate_template, center_survival_time_designs_at_anchor,
1247    evaluate_survival_time_basis_row, initial_survival_baseline_config_for_fit,
1248    marginal_slope_baseline_chain_rule_gradient, marginal_slope_baseline_chain_rule_hessian,
1249    normalize_survival_time_pair, optimize_survival_baseline_config,
1250    optimize_survival_baseline_config_with_gradient, parse_survival_distribution,
1251    parse_survival_likelihood_mode, parse_survival_time_basis_config, positive_survival_time_seed,
1252    require_structural_survival_time_basis, resolve_survival_time_anchor_value,
1253    resolved_survival_time_basis_config_from_build, survival_derivative_guard_for_likelihood,
1254};
1255use crate::families::survival_location_scale::{
1256    SurvivalCovariateTermBlockTemplate, TimeBlockInput, TimeWiggleBlockInput,
1257    residual_distribution_inverse_link,
1258};
1259use crate::inference::data::EncodedDataset as Dataset;
1260use crate::inference::formula_dsl::{
1261    LinkChoice, LinkWiggleFormulaSpec, ParsedFormula, ParsedTerm, effectivelinkwiggle_formulaspec,
1262    parse_formula, parse_link_choice, parse_matching_auxiliary_formula, parse_surv_response,
1263    require_inverse_link_supports_joint_wiggle, validate_marginal_slope_z_column_exclusion,
1264};
1265use crate::term_builder::{
1266    build_termspec, column_map_with_alias, enable_scale_dimensions, resolve_role_col,
1267};
1268
/// Non-formula configuration for model fitting. All fields have sensible defaults.
///
/// Construct with [`FitConfig::default`] and override individual fields. For
/// responses that are not `Surv(...)`, `materialize` picks the model path from
/// these fields in this order: `transformation_normal`, then
/// `logslope_formula` / `z_column`, then `noise_formula`, otherwise the
/// standard single-predictor fit.
#[derive(Clone, Debug)]
pub struct FitConfig {
    /// Family: "gaussian", "binomial", "poisson", "gamma", or None for auto-detect.
    // NOTE(review): `resolve_family` in this file also recognizes
    // "binomial-logit", "binomial-probit", "binomial-cloglog" and
    // "latent-cloglog-binomial" — presumably those apply here too; confirm.
    pub family: Option<String>,
    /// Link: "identity", "logit", "probit", "cloglog", "sas", "beta-logistic", or None.
    pub link: Option<String>,
    /// Whether to use flexible (wiggle-augmented) link.
    pub flexible_link: bool,
    /// Optional additive offset column for the primary linear predictor.
    pub offset_column: Option<String>,
    /// Optional additive offset column for the noise/log-scale predictor.
    pub noise_offset_column: Option<String>,
    /// Optional family-level frailty modifier.
    pub frailty: Option<FrailtySpec>,

    // Survival-specific
    /// Baseline target: "linear", "weibull", "gompertz", "gompertz-makeham".
    /// Defaults to "linear".
    pub baseline_target: String,
    // The four optional baseline parameters below default to `None`.
    // NOTE(review): presumably each is only consulted by the baseline target
    // that uses it (scale/shape for Weibull, rate/shape for Gompertz, makeham
    // for Gompertz-Makeham) — confirm against the survival construction code.
    pub baseline_scale: Option<f64>,
    pub baseline_shape: Option<f64>,
    pub baseline_rate: Option<f64>,
    pub baseline_makeham: Option<f64>,
    /// Time basis: "ispline" or "none". Defaults to "ispline".
    pub time_basis: String,
    // Time-basis settings; defaults are degree 3, 8 internal knots, and a
    // smoothing lambda of 1e-2.
    pub time_degree: usize,
    pub time_num_internal_knots: usize,
    pub time_smooth_lambda: f64,
    /// Survival likelihood mode: "location-scale", "transformation", "weibull",
    /// "marginal-slope", "latent", or "latent-binary".
    /// Defaults to "location-scale".
    pub survival_likelihood: String,
    /// Residual distribution: "gaussian", "logistic", "gumbel".
    /// Defaults to "gaussian".
    pub survival_distribution: String,
    // Optional basis sizes (`*_k`, default `None`) and degrees (default 3).
    // NOTE(review): the names suggest time-varying threshold and sigma
    // blocks — confirm where these are consumed.
    pub threshold_time_k: Option<usize>,
    pub threshold_time_degree: usize,
    pub sigma_time_k: Option<usize>,
    pub sigma_time_degree: usize,

    // Location-scale (GAMLSS)
    /// If set, fit a location-scale model with this formula for the noise parameter.
    pub noise_formula: Option<String>,

    // Marginal-slope
    /// Formula for the log-slope model (survival marginal-slope or Bernoulli marginal-slope).
    /// For a non-`Surv` response, setting this or `z_column` makes
    /// `materialize` build a Bernoulli marginal-slope request.
    pub logslope_formula: Option<String>,
    /// Column name for the z (exposure/dose) variable in marginal-slope models.
    pub z_column: Option<String>,
    /// Optional non-negative per-row training weights column.
    pub weight_column: Option<String>,

    // Fitting options
    // `scale_dimensions` defaults to `false`.
    // NOTE(review): presumably this is the flag handed to
    // `build_termspec_with_geometry`, which calls `enable_scale_dimensions`
    // on the built term spec when it is true — confirm at the call sites.
    pub scale_dimensions: bool,
    // Defaults to 1e-6.
    pub ridge_lambda: f64,

    /// Route the fit through the transformation-normal family.  When set, the
    /// formula terms are treated as the covariate side of the transformation
    /// model and the response basis is built internally.  Incompatible with
    /// `noise_formula` and with `Surv(...)` responses.
    pub transformation_normal: bool,

    /// Enable Firth bias reduction for standard single-parameter families.
    pub firth: bool,

    /// Optional override of the [`crate::resource::ResourcePolicy`] used when
    /// planning spatial bases (TPS / Matern / Duchon) during term construction.
    /// When `None`, the default-library policy is used.
    pub resource_policy: Option<crate::resource::ResourcePolicy>,
}
1337
1338impl Default for FitConfig {
1339    fn default() -> Self {
1340        Self {
1341            family: None,
1342            link: None,
1343            flexible_link: false,
1344            offset_column: None,
1345            noise_offset_column: None,
1346            frailty: None,
1347            baseline_target: "linear".into(),
1348            baseline_scale: None,
1349            baseline_shape: None,
1350            baseline_rate: None,
1351            baseline_makeham: None,
1352            time_basis: "ispline".into(),
1353            time_degree: 3,
1354            time_num_internal_knots: 8,
1355            time_smooth_lambda: 1e-2,
1356            survival_likelihood: "location-scale".into(),
1357            survival_distribution: "gaussian".into(),
1358            threshold_time_k: None,
1359            threshold_time_degree: 3,
1360            sigma_time_k: None,
1361            sigma_time_degree: 3,
1362            noise_formula: None,
1363            logslope_formula: None,
1364            z_column: None,
1365            weight_column: None,
1366            scale_dimensions: false,
1367            ridge_lambda: 1e-6,
1368            transformation_normal: false,
1369            firth: false,
1370            resource_policy: None,
1371        }
1372    }
1373}
1374
1375/// Resolve the [`crate::resource::ResourcePolicy`] backing term construction
1376/// for a given [`FitConfig`].  Returns the configured override when present,
1377/// otherwise the default-library policy.
1378fn resolved_resource_policy(config: &FitConfig) -> crate::resource::ResourcePolicy {
1379    config
1380        .resource_policy
1381        .clone()
1382        .unwrap_or_else(crate::resource::ResourcePolicy::default_library)
1383}
1384
/// The result of materializing a formula + config against a dataset.
///
/// Produced by `materialize`; `fit_from_formula` passes `request` straight to
/// `fit_model`.
pub struct MaterializedModel<'a> {
    /// The ready-to-fit request; it borrows from the dataset for `'a`.
    pub request: FitRequest<'a>,
    /// Human-readable notes collected while materializing.
    // NOTE(review): presumably these are the notes the term builder appends
    // to its `inference_notes` argument — confirm in the `materialize_*`
    // helpers.
    pub inference_notes: Vec<String>,
}
1390
1391/// Parse, materialize, and fit a model in one call.
1392pub fn fit_from_formula(
1393    formula: &str,
1394    data: &Dataset,
1395    config: &FitConfig,
1396) -> Result<FitResult, String> {
1397    let mat = materialize(formula, data, config)?;
1398    fit_model(mat.request)
1399}
1400
1401/// Parse a formula, resolve it against a dataset, and produce a ready-to-fit `FitRequest`.
1402pub fn materialize<'a>(
1403    formula: &str,
1404    data: &'a Dataset,
1405    config: &FitConfig,
1406) -> Result<MaterializedModel<'a>, String> {
1407    let parsed = parse_formula(formula)?;
1408    let col_map = data.column_map();
1409
1410    if let Some((entry_col, exit_col, event_col)) = parse_surv_response(&parsed.response)? {
1411        if config.transformation_normal {
1412            return Err(
1413                "transformation_normal cannot be combined with a Surv(...) response".to_string(),
1414            );
1415        }
1416        materialize_survival(
1417            &parsed, data, &col_map, config, &entry_col, &exit_col, &event_col,
1418        )
1419    } else if config.transformation_normal {
1420        if config.noise_formula.is_some() {
1421            return Err("transformation_normal cannot be combined with noise_formula".to_string());
1422        }
1423        materialize_transformation_normal(&parsed, data, &col_map, config)
1424    } else if config.logslope_formula.is_some() || config.z_column.is_some() {
1425        materialize_bernoulli_marginal_slope(&parsed, data, &col_map, config)
1426    } else if config.noise_formula.is_some() {
1427        materialize_location_scale(&parsed, data, &col_map, config)
1428    } else {
1429        materialize_standard(&parsed, data, &col_map, config)
1430    }
1431}
1432
1433/// Detect whether a response column is binary (0/1 only).
1434pub fn is_binary_response(y: ArrayView1<'_, f64>) -> bool {
1435    if y.is_empty() {
1436        return false;
1437    }
1438    y.iter()
1439        .all(|v| (*v - 0.0).abs() < 1e-12 || (*v - 1.0).abs() < 1e-12)
1440}
1441
1442/// Resolve a family from an optional name, optional link choice, and response data.
1443pub fn resolve_family(
1444    family: Option<&str>,
1445    link_choice: Option<&LinkChoice>,
1446    y: ArrayView1<'_, f64>,
1447) -> Result<LikelihoodFamily, String> {
1448    let explicit = family.and_then(|name| match name.to_ascii_lowercase().as_str() {
1449        "gaussian" => Some(LikelihoodFamily::GaussianIdentity),
1450        "binomial" | "binomial-logit" => Some(LikelihoodFamily::BinomialLogit),
1451        "binomial-probit" => Some(LikelihoodFamily::BinomialProbit),
1452        "binomial-cloglog" => Some(LikelihoodFamily::BinomialCLogLog),
1453        "latent-cloglog-binomial" => Some(LikelihoodFamily::BinomialLatentCLogLog),
1454        "poisson" => Some(LikelihoodFamily::PoissonLog),
1455        "gamma" => Some(LikelihoodFamily::GammaLog),
1456        _ => None,
1457    });
1458
1459    if let Some(choice) = link_choice {
1460        let from_link = if choice.mixture_components.is_some() {
1461            LikelihoodFamily::BinomialMixture
1462        } else {
1463            match choice.link {
1464                LinkFunction::Identity => LikelihoodFamily::GaussianIdentity,
1465                LinkFunction::Log => {
1466                    if y.iter()
1467                        .all(|&yi| yi.is_finite() && yi >= 0.0 && (yi - yi.round()).abs() <= 1e-9)
1468                    {
1469                        LikelihoodFamily::PoissonLog
1470                    } else {
1471                        LikelihoodFamily::GammaLog
1472                    }
1473                }
1474                LinkFunction::Logit => LikelihoodFamily::BinomialLogit,
1475                LinkFunction::Probit => LikelihoodFamily::BinomialProbit,
1476                LinkFunction::CLogLog => LikelihoodFamily::BinomialCLogLog,
1477                LinkFunction::Sas => LikelihoodFamily::BinomialSas,
1478                LinkFunction::BetaLogistic => LikelihoodFamily::BinomialBetaLogistic,
1479            }
1480        };
1481        if let Some(explicit_family) = explicit {
1482            if explicit_family != from_link {
1483                return Err(format!(
1484                    "family '{}' conflicts with link",
1485                    family_to_string(explicit_family)
1486                ));
1487            }
1488        }
1489        return Ok(from_link);
1490    }
1491
1492    if let Some(f) = explicit {
1493        return Ok(f);
1494    }
1495
1496    // Auto-detect
1497    if is_binary_response(y) {
1498        Ok(LikelihoodFamily::BinomialLogit)
1499    } else {
1500        Ok(LikelihoodFamily::GaussianIdentity)
1501    }
1502}
1503
1504// ---------------------------------------------------------------------------
1505// Internal helpers
1506// ---------------------------------------------------------------------------
1507
1508fn build_termspec_with_geometry(
1509    terms: &[ParsedTerm],
1510    data: &Dataset,
1511    col_map: &HashMap<String, usize>,
1512    inference_notes: &mut Vec<String>,
1513    scale_dimensions: bool,
1514    policy: &crate::resource::ResourcePolicy,
1515) -> Result<TermCollectionSpec, String> {
1516    let mut spec = build_termspec(terms, data, col_map, inference_notes, policy)?;
1517    if scale_dimensions {
1518        enable_scale_dimensions(&mut spec);
1519    }
1520    Ok(spec)
1521}
1522
1523fn resolve_survival_marginal_slope_base_link(
1524    linkspec: Option<&crate::inference::formula_dsl::LinkFormulaSpec>,
1525) -> Result<InverseLink, String> {
1526    let Some(linkspec) = linkspec else {
1527        return Ok(InverseLink::Standard(LinkFunction::Probit));
1528    };
1529    let choice = parse_link_choice(Some(&linkspec.link), false)?
1530        .ok_or_else(|| "invalid survival marginal-slope link".to_string())?;
1531    if choice.mixture_components.is_some() {
1532        return Err(
1533            "survival marginal-slope currently supports only link(type=probit)".to_string(),
1534        );
1535    }
1536    match choice.link {
1537        LinkFunction::Probit => Ok(InverseLink::Standard(LinkFunction::Probit)),
1538        other => Err(format!(
1539            "survival marginal-slope currently supports only link(type=probit), got {other:?}"
1540        )),
1541    }
1542}
1543
1544struct PreparedWorkflowSurvivalTimeStack {
1545    eta_offset_entry: Array1<f64>,
1546    eta_offset_exit: Array1<f64>,
1547    derivative_offset_exit: Array1<f64>,
1548    unloaded_mass_entry: Array1<f64>,
1549    unloaded_mass_exit: Array1<f64>,
1550    unloaded_hazard_exit: Array1<f64>,
1551    time_design_entry: crate::matrix::DesignMatrix,
1552    time_design_exit: crate::matrix::DesignMatrix,
1553    time_design_derivative: crate::matrix::DesignMatrix,
1554    time_penalties: Vec<Array2<f64>>,
1555    time_nullspace_dims: Vec<usize>,
1556    timewiggle_block: Option<TimeWiggleBlockInput>,
1557}
1558
1559fn prepare_workflow_survival_time_stack(
1560    age_entry: &Array1<f64>,
1561    age_exit: &Array1<f64>,
1562    baseline_cfg: &crate::families::survival_construction::SurvivalBaselineConfig,
1563    likelihood_mode: SurvivalLikelihoodMode,
1564    inverse_link: Option<&InverseLink>,
1565    time_anchor: f64,
1566    derivative_guard: f64,
1567    time_build: &crate::families::survival_construction::SurvivalTimeBuildOutput,
1568    effective_timewiggle: Option<&LinkWiggleFormulaSpec>,
1569    latent_loading: Option<crate::families::lognormal_kernel::HazardLoading>,
1570) -> Result<PreparedWorkflowSurvivalTimeStack, String> {
1571    let (
1572        mut eta_offset_entry,
1573        mut eta_offset_exit,
1574        mut derivative_offset_exit,
1575        unloaded_mass_entry,
1576        unloaded_mass_exit,
1577        unloaded_hazard_exit,
1578    ) = if let Some(loading) = latent_loading {
1579        let offsets =
1580            build_latent_survival_baseline_offsets(age_entry, age_exit, baseline_cfg, loading)?;
1581        (
1582            offsets.loaded_eta_entry,
1583            offsets.loaded_eta_exit,
1584            offsets.loaded_derivative_exit,
1585            offsets.unloaded_mass_entry,
1586            offsets.unloaded_mass_exit,
1587            offsets.unloaded_hazard_exit,
1588        )
1589    } else {
1590        let (eta_offset_entry, eta_offset_exit, derivative_offset_exit) =
1591            build_survival_time_offsets_for_likelihood(
1592                age_entry,
1593                age_exit,
1594                baseline_cfg,
1595                likelihood_mode,
1596                inverse_link,
1597            )?;
1598        let n = age_entry.len();
1599        (
1600            eta_offset_entry,
1601            eta_offset_exit,
1602            derivative_offset_exit,
1603            Array1::zeros(n),
1604            Array1::zeros(n),
1605            Array1::zeros(n),
1606        )
1607    };
1608    add_survival_time_derivative_guard_offset(
1609        age_entry,
1610        age_exit,
1611        time_anchor,
1612        derivative_guard,
1613        &mut eta_offset_entry,
1614        &mut eta_offset_exit,
1615        &mut derivative_offset_exit,
1616    )?;
1617    let timewiggle_build = if let Some(cfg) = effective_timewiggle {
1618        Some(build_survival_timewiggle_from_baseline(
1619            &eta_offset_entry,
1620            &eta_offset_exit,
1621            &derivative_offset_exit,
1622            cfg,
1623        )?)
1624    } else {
1625        None
1626    };
1627    let mut time_design_entry = time_build.x_entry_time.clone();
1628    let mut time_design_exit = time_build.x_exit_time.clone();
1629    let mut time_design_derivative = time_build.x_derivative_time.clone();
1630    let mut time_penalties = time_build.penalties.clone();
1631    let mut time_nullspace_dims = time_build.nullspace_dims.clone();
1632    let mut timewiggle_block = None;
1633    if let Some(wiggle) = timewiggle_build.as_ref() {
1634        let p_base = time_design_exit.ncols();
1635        append_zero_tail_columns(
1636            &mut time_design_entry,
1637            &mut time_design_exit,
1638            &mut time_design_derivative,
1639            wiggle.ncols,
1640        );
1641        for (idx, penalty) in wiggle.penalties.iter().enumerate() {
1642            let mut embedded = Array2::<f64>::zeros((p_base + wiggle.ncols, p_base + wiggle.ncols));
1643            embedded
1644                .slice_mut(s![
1645                    p_base..p_base + wiggle.ncols,
1646                    p_base..p_base + wiggle.ncols
1647                ])
1648                .assign(penalty);
1649            time_penalties.push(embedded);
1650            time_nullspace_dims.push(wiggle.nullspace_dims.get(idx).copied().unwrap_or(0));
1651        }
1652        timewiggle_block = Some(TimeWiggleBlockInput {
1653            knots: wiggle.knots.clone(),
1654            degree: wiggle.degree,
1655            ncols: wiggle.ncols,
1656        });
1657    }
1658    Ok(PreparedWorkflowSurvivalTimeStack {
1659        eta_offset_entry,
1660        eta_offset_exit,
1661        derivative_offset_exit,
1662        unloaded_mass_entry,
1663        unloaded_mass_exit,
1664        unloaded_hazard_exit,
1665        time_design_entry,
1666        time_design_exit,
1667        time_design_derivative,
1668        time_penalties,
1669        time_nullspace_dims,
1670        timewiggle_block,
1671    })
1672}
1673
1674fn resolve_continuous_column(
1675    data: &Dataset,
1676    col_map: &HashMap<String, usize>,
1677    column_name: &str,
1678    role: &str,
1679) -> Result<Array1<f64>, String> {
1680    let col_idx = resolve_role_col(col_map, column_name, role)?;
1681    let values = data.values.column(col_idx).to_owned();
1682    for (row_idx, value) in values.iter().enumerate() {
1683        if !value.is_finite() {
1684            return Err(format!(
1685                "{role} column '{column_name}' contains non-finite value at row {row_idx}: {value}"
1686            ));
1687        }
1688    }
1689    Ok(values)
1690}
1691
1692pub fn resolve_offset_column(
1693    data: &Dataset,
1694    col_map: &HashMap<String, usize>,
1695    column_name: Option<&str>,
1696) -> Result<Array1<f64>, String> {
1697    let Some(column_name) = column_name else {
1698        return Ok(Array1::zeros(data.values.nrows()));
1699    };
1700    resolve_continuous_column(data, col_map, column_name, "offset")
1701}
1702
1703pub fn resolve_weight_column(
1704    data: &Dataset,
1705    col_map: &HashMap<String, usize>,
1706    column_name: Option<&str>,
1707) -> Result<Array1<f64>, String> {
1708    let Some(column_name) = column_name else {
1709        return Ok(Array1::ones(data.values.nrows()));
1710    };
1711    let values = resolve_continuous_column(data, col_map, column_name, "weights")?;
1712    for (row_idx, value) in values.iter().enumerate() {
1713        if *value < 0.0 {
1714            return Err(format!(
1715                "weights column '{column_name}' must be non-negative; found {value} at row {row_idx}"
1716            ));
1717        }
1718    }
1719    Ok(values)
1720}
1721
1722fn materialize_standard<'a>(
1723    parsed: &ParsedFormula,
1724    data: &'a Dataset,
1725    col_map: &HashMap<String, usize>,
1726    config: &FitConfig,
1727) -> Result<MaterializedModel<'a>, String> {
1728    if config.noise_offset_column.is_some() {
1729        return Err(
1730            "noise_offset_column requires a location-scale model with noise_formula".to_string(),
1731        );
1732    }
1733    let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
1734    let y = data.values.column(y_col).to_owned();
1735    let mut inference_notes = Vec::new();
1736
1737    let link_choice = parse_link_choice(config.link.as_deref(), config.flexible_link)?;
1738    let family = resolve_family(config.family.as_deref(), link_choice.as_ref(), y.view())?;
1739
1740    let effective_linkwiggle =
1741        effectivelinkwiggle_formulaspec(parsed.linkwiggle.as_ref(), link_choice.as_ref());
1742
1743    let policy = resolved_resource_policy(config);
1744    let spec = build_termspec_with_geometry(
1745        &parsed.terms,
1746        data,
1747        col_map,
1748        &mut inference_notes,
1749        config.scale_dimensions,
1750        &policy,
1751    )?;
1752
1753    let weights = resolve_weight_column(data, col_map, config.weight_column.as_deref())?;
1754    let offset = resolve_offset_column(data, col_map, config.offset_column.as_deref())?;
1755    let latent_cloglog = if matches!(family, LikelihoodFamily::BinomialLatentCLogLog) {
1756        let sigma = match config.frailty.clone().unwrap_or(FrailtySpec::None) {
1757            FrailtySpec::HazardMultiplier {
1758                sigma_fixed: Some(sigma),
1759                loading: crate::families::lognormal_kernel::HazardLoading::Full,
1760            } => sigma,
1761            FrailtySpec::HazardMultiplier {
1762                sigma_fixed: Some(_),
1763                loading,
1764            } => {
1765                return Err(format!(
1766                    "latent-cloglog-binomial requires HazardLoading::Full, got {loading:?}"
1767                ));
1768            }
1769            FrailtySpec::HazardMultiplier {
1770                sigma_fixed: None, ..
1771            } => {
1772                return Err(
1773                    "latent-cloglog-binomial currently requires a fixed hazard-multiplier sigma"
1774                        .to_string(),
1775                );
1776            }
1777            FrailtySpec::GaussianShift { .. } => {
1778                return Err(
1779                    "latent-cloglog-binomial does not support GaussianShift frailty".to_string(),
1780                );
1781            }
1782            FrailtySpec::None => {
1783                return Err(
1784                    "latent-cloglog-binomial requires config.frailty=HazardMultiplier with a fixed sigma"
1785                        .to_string(),
1786                );
1787            }
1788        };
1789        Some(
1790            LatentCLogLogState::new(sigma)
1791                .map_err(|e| format!("invalid latent_cloglog state: {e}"))?,
1792        )
1793    } else {
1794        if config.frailty.is_some() {
1795            return Err(format!(
1796                "config.frailty is not supported for standard family {:?}; use a frailty-aware family instead",
1797                family
1798            ));
1799        }
1800        None
1801    };
1802    let options = FitOptions {
1803        latent_cloglog,
1804        mixture_link: None,
1805        optimize_mixture: false,
1806        sas_link: None,
1807        optimize_sas: false,
1808        compute_inference: true,
1809        max_iter: 200,
1810        tol: 1e-7,
1811        nullspace_dims: vec![],
1812        linear_constraints: None,
1813        firth_bias_reduction: config.firth,
1814        adaptive_regularization: None,
1815        penalty_shrinkage_floor: Some(1e-6),
1816        rho_prior: Default::default(),
1817        kronecker_penalty_system: None,
1818        kronecker_factored: None,
1819    };
1820    let kappa_options = SpatialLengthScaleOptimizationOptions::default();
1821
1822    let wiggle = effective_linkwiggle.as_ref().and_then(|cfg| {
1823        if !is_binomial_family(family) {
1824            return None;
1825        }
1826        let link_kind = link_choice
1827            .as_ref()
1828            .map(|c| InverseLink::Standard(c.link))
1829            .unwrap_or_else(|| {
1830                if let Some(state) = latent_cloglog {
1831                    InverseLink::LatentCLogLog(state)
1832                } else {
1833                    InverseLink::Standard(LinkFunction::Logit)
1834                }
1835            });
1836        Some(StandardBinomialWiggleConfig {
1837            link_kind,
1838            wiggle: LinkWiggleConfig {
1839                degree: cfg.degree,
1840                num_internal_knots: cfg.num_internal_knots,
1841                penalty_orders: cfg.penalty_orders.clone(),
1842                double_penalty: cfg.double_penalty,
1843            },
1844        })
1845    });
1846
1847    Ok(MaterializedModel {
1848        request: FitRequest::Standard(StandardFitRequest {
1849            data: data.values.view(),
1850            y,
1851            weights,
1852            offset,
1853            spec,
1854            family,
1855            options,
1856            kappa_options,
1857            wiggle,
1858            wiggle_options: None,
1859        }),
1860        inference_notes,
1861    })
1862}
1863
1864fn materialize_bernoulli_marginal_slope<'a>(
1865    parsed: &ParsedFormula,
1866    data: &'a Dataset,
1867    col_map: &HashMap<String, usize>,
1868    config: &FitConfig,
1869) -> Result<MaterializedModel<'a>, String> {
1870    let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
1871    let y = data.values.column(y_col).to_owned();
1872
1873    if !is_binary_response(y.view()) {
1874        return Err("Bernoulli marginal-slope requires a binary {0,1} response".to_string());
1875    }
1876    if config.noise_formula.is_some() {
1877        return Err("Bernoulli marginal-slope cannot also use noise_formula".to_string());
1878    }
1879
1880    let logslope_formula = config
1881        .logslope_formula
1882        .as_deref()
1883        .ok_or_else(|| "Bernoulli marginal-slope requires logslope_formula".to_string())?;
1884    let z_column = config
1885        .z_column
1886        .as_deref()
1887        .ok_or_else(|| "Bernoulli marginal-slope requires z_column".to_string())?;
1888
1889    let (_, parsed_logslope) =
1890        parse_matching_auxiliary_formula(logslope_formula, &parsed.response, "logslope_formula")?;
1891    if parsed_logslope.linkspec.is_some() {
1892        return Err("link(...) is not supported inside logslope_formula".to_string());
1893    }
1894    validate_marginal_slope_z_column_exclusion(
1895        parsed,
1896        &parsed_logslope,
1897        z_column,
1898        "Bernoulli marginal-slope",
1899        "logslope_formula",
1900    )?;
1901
1902    let mut inference_notes = Vec::new();
1903    let policy = resolved_resource_policy(config);
1904    let aliased_col_map = column_map_with_alias(col_map, "z", z_column);
1905    let marginalspec = build_termspec_with_geometry(
1906        &parsed.terms,
1907        data,
1908        &aliased_col_map,
1909        &mut inference_notes,
1910        config.scale_dimensions,
1911        &policy,
1912    )?;
1913    let logslopespec = build_termspec_with_geometry(
1914        &parsed_logslope.terms,
1915        data,
1916        &aliased_col_map,
1917        &mut inference_notes,
1918        config.scale_dimensions,
1919        &policy,
1920    )?;
1921    let z_idx = resolve_role_col(col_map, z_column, "z")?;
1922    let z = data.values.column(z_idx).to_owned();
1923    let weights = resolve_weight_column(data, col_map, config.weight_column.as_deref())?;
1924    let marginal_offset = resolve_offset_column(data, col_map, config.offset_column.as_deref())?;
1925    let logslope_offset =
1926        resolve_offset_column(data, col_map, config.noise_offset_column.as_deref())?;
1927    let routing = route_marginal_slope_deviation_blocks(
1928        parsed.linkwiggle.as_ref(),
1929        parsed_logslope.linkwiggle.as_ref(),
1930    )?;
1931    let spec = BernoulliMarginalSlopeTermSpec {
1932        y,
1933        weights,
1934        z,
1935        base_link: InverseLink::Standard(LinkFunction::Probit),
1936        marginalspec,
1937        logslopespec,
1938        marginal_offset,
1939        logslope_offset,
1940        frailty: config.frailty.clone().unwrap_or(FrailtySpec::None),
1941        score_warp: routing.score_warp,
1942        link_dev: routing.link_dev,
1943        latent_z_policy: Default::default(),
1944    };
1945
1946    Ok(MaterializedModel {
1947        request: FitRequest::BernoulliMarginalSlope(BernoulliMarginalSlopeFitRequest {
1948            data: data.values.view(),
1949            spec,
1950            options: BlockwiseFitOptions {
1951                compute_covariance: true,
1952                ..Default::default()
1953            },
1954            kappa_options: SpatialLengthScaleOptimizationOptions::default(),
1955            policy,
1956        }),
1957        inference_notes,
1958    })
1959}
1960
1961fn materialize_survival<'a>(
1962    parsed: &ParsedFormula,
1963    data: &'a Dataset,
1964    col_map: &HashMap<String, usize>,
1965    config: &FitConfig,
1966    entry_col: &str,
1967    exit_col: &str,
1968    event_col: &str,
1969) -> Result<MaterializedModel<'a>, String> {
1970    let mut inference_notes = Vec::new();
1971
1972    // Extract columns
1973    let entry_idx = resolve_role_col(col_map, entry_col, "entry")?;
1974    let exit_idx = resolve_role_col(col_map, exit_col, "exit")?;
1975    let event_idx = resolve_role_col(col_map, event_col, "event")?;
1976    use rayon::iter::{IntoParallelIterator, ParallelIterator};
1977    let n = data.values.nrows();
1978    let event = data.values.column(event_idx).to_owned();
1979    let pairs: Result<Vec<(f64, f64)>, String> = (0..n)
1980        .into_par_iter()
1981        .map(|i| {
1982            normalize_survival_time_pair(data.values[[i, entry_idx]], data.values[[i, exit_idx]], i)
1983        })
1984        .collect();
1985    let pairs = pairs?;
1986    let mut age_entry = Array1::<f64>::zeros(n);
1987    let mut age_exit = Array1::<f64>::zeros(n);
1988    for (i, (e, x)) in pairs.into_iter().enumerate() {
1989        age_entry[i] = e;
1990        age_exit[i] = x;
1991    }
1992
1993    let survival_mode = parse_survival_likelihood_mode(&config.survival_likelihood)?;
1994    if parsed.linkwiggle.is_some()
1995        && !matches!(
1996            survival_mode,
1997            SurvivalLikelihoodMode::LocationScale | SurvivalLikelihoodMode::MarginalSlope
1998        )
1999    {
2000        return Err(format!(
2001            "linkwiggle(...) is not defined for survival_likelihood='{}'",
2002            config.survival_likelihood
2003        ));
2004    }
2005    if parsed.linkspec.is_some()
2006        && matches!(
2007            survival_mode,
2008            SurvivalLikelihoodMode::Transformation
2009                | SurvivalLikelihoodMode::Weibull
2010                | SurvivalLikelihoodMode::Latent
2011                | SurvivalLikelihoodMode::LatentBinary
2012        )
2013    {
2014        return Err(format!(
2015            "link(...) is not implemented for survival_likelihood='{}'",
2016            config.survival_likelihood
2017        ));
2018    }
2019    let effective_timewiggle = parsed.timewiggle.clone();
2020    let baseline_target_raw = match survival_mode {
2021        SurvivalLikelihoodMode::Weibull if effective_timewiggle.is_some() => "weibull",
2022        SurvivalLikelihoodMode::Weibull => "linear",
2023        _ => &config.baseline_target,
2024    };
2025    let baseline_cfg = initial_survival_baseline_config_for_fit(
2026        baseline_target_raw,
2027        config.baseline_scale,
2028        config.baseline_shape,
2029        config.baseline_rate,
2030        config.baseline_makeham,
2031        &age_exit,
2032    )?;
2033    if matches!(
2034        survival_mode,
2035        SurvivalLikelihoodMode::Latent | SurvivalLikelihoodMode::LatentBinary
2036    ) && baseline_cfg.target == SurvivalBaselineTarget::Linear
2037    {
2038        return Err(
2039            "latent hazard-window families require a non-linear scalar baseline target; use baseline_target weibull, gompertz, or gompertz-makeham"
2040                .to_string(),
2041        );
2042    }
2043    let time_cfg = if effective_timewiggle.is_some() {
2044        // Match the CLI path: the parametric baseline plus timewiggle supplies
2045        // the time structure, so the base time basis is disabled.
2046        SurvivalTimeBasisConfig::None
2047    } else if survival_mode == SurvivalLikelihoodMode::Weibull {
2048        SurvivalTimeBasisConfig::Linear
2049    } else {
2050        parse_survival_time_basis_config(
2051            &config.time_basis,
2052            config.time_degree,
2053            config.time_num_internal_knots,
2054            config.time_smooth_lambda,
2055        )?
2056    };
2057    let time_anchor = resolve_survival_time_anchor_value(&age_entry, None)?;
2058    let exact_derivative_guard = survival_derivative_guard_for_likelihood(survival_mode);
2059
2060    // Build time basis
2061    let mut time_build = build_survival_time_basis(
2062        &age_entry,
2063        &age_exit,
2064        time_cfg.clone(),
2065        Some((config.time_num_internal_knots, config.time_smooth_lambda)),
2066    )?;
2067    if survival_mode != SurvivalLikelihoodMode::Weibull && effective_timewiggle.is_none() {
2068        require_structural_survival_time_basis(&time_build.basisname, "workflow survival fitting")?;
2069    }
2070    let resolved_time_cfg = resolved_survival_time_basis_config_from_build(
2071        &time_build.basisname,
2072        time_build.degree,
2073        time_build.knots.as_ref(),
2074        time_build.keep_cols.as_ref(),
2075        time_build.smooth_lambda,
2076    )?;
2077    let time_anchor_row = evaluate_survival_time_basis_row(time_anchor, &resolved_time_cfg)?;
2078    center_survival_time_designs_at_anchor(
2079        &mut time_build.x_entry_time,
2080        &mut time_build.x_exit_time,
2081        &time_anchor_row,
2082    )?;
2083    if effective_timewiggle.is_some() && baseline_cfg.target == SurvivalBaselineTarget::Linear {
2084        return Err(
2085            "timewiggle requires a non-linear scalar survival baseline target; \
2086             use baseline_target weibull, gompertz, or gompertz-makeham"
2087                .to_string(),
2088        );
2089    }
2090
2091    let policy = resolved_resource_policy(config);
2092    let marginal_slope_aliased_col_map = if survival_mode == SurvivalLikelihoodMode::MarginalSlope {
2093        Some(column_map_with_alias(
2094            col_map,
2095            "z",
2096            config.z_column.as_deref().ok_or_else(|| {
2097                "marginal-slope survival requires z_column in FitConfig".to_string()
2098            })?,
2099        ))
2100    } else {
2101        None
2102    };
2103    let termspec_col_map = marginal_slope_aliased_col_map.as_ref().unwrap_or(col_map);
2104    let termspec = build_termspec_with_geometry(
2105        &parsed.terms,
2106        data,
2107        termspec_col_map,
2108        &mut inference_notes,
2109        config.scale_dimensions,
2110        &policy,
2111    )?;
2112
2113    let residual_dist = parse_survival_distribution(&config.survival_distribution)?;
2114    let survival_inverse_link = residual_distribution_inverse_link(residual_dist);
2115    let link_choice = parse_link_choice(config.link.as_deref(), config.flexible_link)?;
2116    let effective_linkwiggle =
2117        effectivelinkwiggle_formulaspec(parsed.linkwiggle.as_ref(), link_choice.as_ref());
2118    let effective_linkwiggle_cfg = effective_linkwiggle.clone().map(|cfg| LinkWiggleConfig {
2119        degree: cfg.degree,
2120        num_internal_knots: cfg.num_internal_knots,
2121        penalty_orders: cfg.penalty_orders,
2122        double_penalty: cfg.double_penalty,
2123    });
2124
2125    let weights = resolve_weight_column(data, col_map, config.weight_column.as_deref())?;
2126    let threshold_offset = resolve_offset_column(data, col_map, config.offset_column.as_deref())?;
2127    let log_sigma_offset =
2128        resolve_offset_column(data, col_map, config.noise_offset_column.as_deref())?;
2129    let threshold_template = if let Some(k) = config.threshold_time_k {
2130        build_time_varying_survival_covariate_template(
2131            &age_entry,
2132            &age_exit,
2133            k,
2134            config.threshold_time_degree,
2135            "threshold",
2136        )?
2137    } else {
2138        SurvivalCovariateTermBlockTemplate::Static
2139    };
2140    let log_sigma_template = if let Some(k) = config.sigma_time_k {
2141        build_time_varying_survival_covariate_template(
2142            &age_entry,
2143            &age_exit,
2144            k,
2145            config.sigma_time_degree,
2146            "sigma",
2147        )?
2148    } else {
2149        SurvivalCovariateTermBlockTemplate::Static
2150    };
2151    let log_sigmaspec = if let Some(noise) = config.noise_formula.as_deref() {
2152        let noise_parsed = parse_formula(&format!("{} ~ {noise}", parsed.response))?;
2153        // Use the same aliased col_map as the main termspec — survival
2154        // marginal-slope reserves `z` as a placeholder for `--z-column`,
2155        // and the logslope/noise formula may reference it too.
2156        build_termspec_with_geometry(
2157            &noise_parsed.terms,
2158            data,
2159            termspec_col_map,
2160            &mut inference_notes,
2161            config.scale_dimensions,
2162            &policy,
2163        )?
2164    } else if survival_mode == SurvivalLikelihoodMode::LocationScale {
2165        termspec.clone()
2166    } else {
2167        TermCollectionSpec {
2168            linear_terms: vec![],
2169            random_effect_terms: vec![],
2170            smooth_terms: vec![],
2171        }
2172    };
2173    let marginal_z_column_name =
2174        if survival_mode == SurvivalLikelihoodMode::MarginalSlope {
2175            Some(config.z_column.as_deref().ok_or_else(|| {
2176                "marginal-slope survival requires z_column in FitConfig".to_string()
2177            })?)
2178        } else {
2179            None
2180        };
2181    let marginal_z = if survival_mode == SurvivalLikelihoodMode::MarginalSlope {
2182        let _base_link = resolve_survival_marginal_slope_base_link(parsed.linkspec.as_ref())?;
2183        let z_col_name = marginal_z_column_name
2184            .expect("marginal-slope z column should be validated before materialization");
2185        let z_idx = resolve_role_col(col_map, z_col_name, "z")?;
2186        Some(data.values.column(z_idx).to_owned())
2187    } else {
2188        None
2189    };
2190    let (marginal_logslopespec, marginal_slope_deviation_routing) = if survival_mode
2191        == SurvivalLikelihoodMode::MarginalSlope
2192    {
2193        if let Some(ls_formula) = config.logslope_formula.as_deref() {
2194            let (_, ls_parsed) =
2195                parse_matching_auxiliary_formula(ls_formula, &parsed.response, "logslope_formula")?;
2196            if ls_parsed.linkspec.is_some() {
2197                return Err(
2198                        "link(...) is not supported in logslope_formula for the survival marginal-slope family"
2199                            .to_string(),
2200                    );
2201            }
2202            if ls_parsed.timewiggle.is_some() {
2203                return Err(
2204                        "timewiggle(...) is not supported in logslope_formula for the survival marginal-slope family"
2205                            .to_string(),
2206                    );
2207            }
2208            if ls_parsed.survivalspec.is_some() {
2209                return Err(
2210                        "survmodel(...) is not supported in logslope_formula for the survival marginal-slope family"
2211                            .to_string(),
2212                    );
2213            }
2214            validate_marginal_slope_z_column_exclusion(
2215                parsed,
2216                &ls_parsed,
2217                marginal_z_column_name.expect("marginal-slope z column should be available"),
2218                "survival marginal-slope",
2219                "logslope_formula",
2220            )?;
2221            (
2222                Some(build_termspec_with_geometry(
2223                    &ls_parsed.terms,
2224                    data,
2225                    marginal_slope_aliased_col_map
2226                        .as_ref()
2227                        .expect("marginal-slope column map should be available"),
2228                    &mut inference_notes,
2229                    config.scale_dimensions,
2230                    &policy,
2231                )?),
2232                route_marginal_slope_deviation_blocks(
2233                    parsed.linkwiggle.as_ref(),
2234                    ls_parsed.linkwiggle.as_ref(),
2235                )?,
2236            )
2237        } else {
2238            validate_marginal_slope_z_column_exclusion(
2239                parsed,
2240                parsed,
2241                marginal_z_column_name.expect("marginal-slope z column should be available"),
2242                "survival marginal-slope",
2243                "logslope_formula",
2244            )?;
2245            (
2246                Some(termspec.clone()),
2247                route_marginal_slope_deviation_blocks(parsed.linkwiggle.as_ref(), None)?,
2248            )
2249        }
2250    } else {
2251        (
2252            None,
2253            MarginalSlopeDeviationRouting {
2254                score_warp: None,
2255                link_dev: None,
2256            },
2257        )
2258    };
2259    let marginal_slope_score_warp = marginal_slope_deviation_routing.score_warp;
2260    let marginal_slope_link_dev = marginal_slope_deviation_routing.link_dev;
2261    if survival_mode == SurvivalLikelihoodMode::MarginalSlope {
2262        if parsed.linkwiggle.is_some() {
2263            inference_notes.push(
2264                "survival marginal-slope routes formula-level linkwiggle(...) into its anchored internal link-deviation block while keeping the probit survival base link".to_string(),
2265            );
2266        }
2267        if marginal_slope_score_warp.is_some() {
2268            inference_notes.push(
2269                "survival marginal-slope routes logslope_formula linkwiggle(...) into its anchored internal score-warp block while keeping the probit survival base link".to_string(),
2270            );
2271        }
2272        if marginal_slope_link_dev.is_none() && marginal_slope_score_warp.is_none() {
2273            inference_notes.push(
2274                "survival marginal-slope rigid mode is algebraic closed-form exact".to_string(),
2275            );
2276        } else {
2277            inference_notes.push(
2278                "survival marginal-slope flexible score/link mode uses calibrated de-nested cubic transport cells with analytic value evaluation and calibrated survival normalization"
2279                    .to_string(),
2280            );
2281        }
2282    }
2283    let marginal_slope_frailty = if survival_mode == SurvivalLikelihoodMode::MarginalSlope {
2284        Some(fixed_gaussian_shift_frailty_from_spec(
2285            config.frailty.as_ref().unwrap_or(&FrailtySpec::None),
2286            "survival marginal-slope",
2287        )?)
2288    } else {
2289        None
2290    };
2291    match survival_mode {
2292        SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull
2293            if config.frailty.is_some() =>
2294        {
2295            return Err(
2296                "frailty is not supported for transformation/weibull survival models".to_string(),
2297            );
2298        }
2299        SurvivalLikelihoodMode::LocationScale if config.frailty.is_some() => {
2300            return Err(
2301                "config.frailty is not implemented for survival-likelihood=location-scale"
2302                    .to_string(),
2303            );
2304        }
2305        SurvivalLikelihoodMode::Latent | SurvivalLikelihoodMode::LatentBinary
2306            if effective_timewiggle.is_some() =>
2307        {
2308            return Err(
2309                "timewiggle is not implemented for latent survival/binary likelihoods".to_string(),
2310            );
2311        }
2312        _ => {}
2313    }
2314    let latent_loading = if matches!(
2315        survival_mode,
2316        SurvivalLikelihoodMode::Latent | SurvivalLikelihoodMode::LatentBinary
2317    ) {
2318        let frailty = config.frailty.as_ref().unwrap_or(&FrailtySpec::None);
2319        Some(latent_hazard_loading(
2320            frailty,
2321            "workflow latent survival/binary",
2322        )?)
2323    } else {
2324        None
2325    };
2326
2327    let build_time_block =
2328        |candidate: &crate::families::survival_construction::SurvivalBaselineConfig| {
2329            let prepared = prepare_workflow_survival_time_stack(
2330                &age_entry,
2331                &age_exit,
2332                candidate,
2333                survival_mode,
2334                (survival_mode == SurvivalLikelihoodMode::LocationScale)
2335                    .then_some(&survival_inverse_link),
2336                time_anchor,
2337                exact_derivative_guard,
2338                &time_build,
2339                effective_timewiggle.as_ref(),
2340                None,
2341            )?;
2342            let time_p = prepared.time_design_exit.ncols();
2343            let time_initial_log_lambdas = if prepared.time_penalties.is_empty() {
2344                None
2345            } else {
2346                Some(Array1::from_elem(
2347                    prepared.time_penalties.len(),
2348                    config.time_smooth_lambda.ln(),
2349                ))
2350            };
2351            let time_block = TimeBlockInput {
2352                design_entry: prepared.time_design_entry.clone(),
2353                design_exit: prepared.time_design_exit.clone(),
2354                design_derivative_exit: prepared.time_design_derivative.clone(),
2355                offset_entry: prepared.eta_offset_entry.clone(),
2356                offset_exit: prepared.eta_offset_exit.clone(),
2357                derivative_offset_exit: prepared.derivative_offset_exit.clone(),
2358                structural_monotonicity: true,
2359                penalties: prepared.time_penalties.clone(),
2360                nullspace_dims: prepared.time_nullspace_dims.clone(),
2361                initial_log_lambdas: time_initial_log_lambdas,
2362                initial_beta: Some(Array1::from_elem(time_p, 1e-4)),
2363            };
2364            Ok::<_, String>((prepared, time_block))
2365        };
2366
2367    let build_location_scale_request =
2368        |candidate: &crate::families::survival_construction::SurvivalBaselineConfig| {
2369            let (prepared, time_block) = build_time_block(candidate)?;
2370            let spec = SurvivalLocationScaleTermSpec {
2371                age_entry: age_entry.clone(),
2372                age_exit: age_exit.clone(),
2373                event_target: event.clone(),
2374                weights: weights.clone(),
2375                inverse_link: survival_inverse_link.clone(),
2376                derivative_guard: exact_derivative_guard,
2377                max_iter: 200,
2378                tol: 1e-7,
2379                time_block,
2380                thresholdspec: termspec.clone(),
2381                log_sigmaspec: log_sigmaspec.clone(),
2382                threshold_offset: threshold_offset.clone(),
2383                log_sigma_offset: log_sigma_offset.clone(),
2384                threshold_template: threshold_template.clone(),
2385                log_sigma_template: log_sigma_template.clone(),
2386                timewiggle_block: prepared.timewiggle_block,
2387                linkwiggle_block: None,
2388            };
2389            let optimize_inverse_link =
2390                survival_inverse_link_has_free_parameters(&spec.inverse_link);
2391            Ok::<_, String>(SurvivalLocationScaleFitRequest {
2392                data: data.values.view(),
2393                spec,
2394                wiggle: effective_linkwiggle_cfg.clone(),
2395                kappa_options: SpatialLengthScaleOptimizationOptions::default(),
2396                optimize_inverse_link,
2397            })
2398        };
2399
2400    let build_marginal_slope_request =
2401        |candidate: &crate::families::survival_construction::SurvivalBaselineConfig| {
2402            let (prepared, time_block) = build_time_block(candidate)?;
2403            Ok::<_, String>(SurvivalMarginalSlopeFitRequest {
2404                data: data.values.view(),
2405                spec: SurvivalMarginalSlopeTermSpec {
2406                    age_entry: age_entry.clone(),
2407                    age_exit: age_exit.clone(),
2408                    event_target: event.clone(),
2409                    weights: weights.clone(),
2410                    z: marginal_z.clone().ok_or_else(|| {
2411                        "marginal-slope survival requires z_column in FitConfig".to_string()
2412                    })?,
2413                    base_link: resolve_survival_marginal_slope_base_link(parsed.linkspec.as_ref())?,
2414                    marginalspec: termspec.clone(),
2415                    marginal_offset: threshold_offset.clone(),
2416                    frailty: marginal_slope_frailty.clone().ok_or_else(|| {
2417                        "internal error: marginal-slope frailty validation missing".to_string()
2418                    })?,
2419                    derivative_guard: exact_derivative_guard,
2420                    time_block,
2421                    timewiggle_block: prepared.timewiggle_block,
2422                    logslopespec: marginal_logslopespec.clone().ok_or_else(|| {
2423                        "marginal-slope survival is missing logslope spec".to_string()
2424                    })?,
2425                    logslope_offset: log_sigma_offset.clone(),
2426                    score_warp: marginal_slope_score_warp.clone(),
2427                    link_dev: marginal_slope_link_dev.clone(),
2428                    latent_z_policy: Default::default(),
2429                },
2430                options: BlockwiseFitOptions {
2431                    compute_covariance: false,
2432                    ..Default::default()
2433                },
2434                kappa_options: SpatialLengthScaleOptimizationOptions::default(),
2435            })
2436        };
2437
2438    let build_latent_survival_request =
2439        |candidate: &crate::families::survival_construction::SurvivalBaselineConfig| {
2440            let loading = latent_loading.ok_or_else(|| {
2441                "internal error: latent survival loading missing after frailty validation"
2442                    .to_string()
2443            })?;
2444            let prepared = prepare_workflow_survival_time_stack(
2445                &age_entry,
2446                &age_exit,
2447                candidate,
2448                survival_mode,
2449                None,
2450                time_anchor,
2451                exact_derivative_guard,
2452                &time_build,
2453                None,
2454                Some(loading),
2455            )?;
2456            let time_p = prepared.time_design_exit.ncols();
2457            let time_initial_log_lambdas = if prepared.time_penalties.is_empty() {
2458                None
2459            } else {
2460                Some(Array1::from_elem(
2461                    prepared.time_penalties.len(),
2462                    config.time_smooth_lambda.ln(),
2463                ))
2464            };
2465            let time_block = TimeBlockInput {
2466                design_entry: prepared.time_design_entry.clone(),
2467                design_exit: prepared.time_design_exit.clone(),
2468                design_derivative_exit: prepared.time_design_derivative.clone(),
2469                offset_entry: prepared.eta_offset_entry.clone(),
2470                offset_exit: prepared.eta_offset_exit.clone(),
2471                derivative_offset_exit: prepared.derivative_offset_exit.clone(),
2472                structural_monotonicity: true,
2473                penalties: prepared.time_penalties.clone(),
2474                nullspace_dims: prepared.time_nullspace_dims.clone(),
2475                initial_log_lambdas: time_initial_log_lambdas,
2476                initial_beta: Some(Array1::from_elem(time_p, 1e-4)),
2477            };
2478            Ok::<_, String>(LatentSurvivalFitRequest {
2479                data: data.values.view(),
2480                spec: LatentSurvivalTermSpec {
2481                    age_entry: age_entry.clone(),
2482                    age_exit: age_exit.clone(),
2483                    event_target: event.mapv(|v| if v >= 0.5 { 1 } else { 0 }),
2484                    weights: weights.clone(),
2485                    derivative_guard: exact_derivative_guard,
2486                    time_block,
2487                    unloaded_mass_entry: prepared.unloaded_mass_entry,
2488                    unloaded_mass_exit: prepared.unloaded_mass_exit,
2489                    unloaded_hazard_exit: prepared.unloaded_hazard_exit,
2490                    meanspec: termspec.clone(),
2491                    mean_offset: threshold_offset.clone(),
2492                },
2493                frailty: config.frailty.clone().unwrap_or(FrailtySpec::None),
2494                options: BlockwiseFitOptions::default(),
2495            })
2496        };
2497
2498    let build_latent_binary_request =
2499        |candidate: &crate::families::survival_construction::SurvivalBaselineConfig| {
2500            let loading = latent_loading.ok_or_else(|| {
2501                "internal error: latent binary loading missing after frailty validation".to_string()
2502            })?;
2503            let prepared = prepare_workflow_survival_time_stack(
2504                &age_entry,
2505                &age_exit,
2506                candidate,
2507                survival_mode,
2508                None,
2509                time_anchor,
2510                exact_derivative_guard,
2511                &time_build,
2512                None,
2513                Some(loading),
2514            )?;
2515            let time_p = prepared.time_design_exit.ncols();
2516            let time_initial_log_lambdas = if prepared.time_penalties.is_empty() {
2517                None
2518            } else {
2519                Some(Array1::from_elem(
2520                    prepared.time_penalties.len(),
2521                    config.time_smooth_lambda.ln(),
2522                ))
2523            };
2524            let time_block = TimeBlockInput {
2525                design_entry: prepared.time_design_entry.clone(),
2526                design_exit: prepared.time_design_exit.clone(),
2527                design_derivative_exit: prepared.time_design_derivative.clone(),
2528                offset_entry: prepared.eta_offset_entry.clone(),
2529                offset_exit: prepared.eta_offset_exit.clone(),
2530                derivative_offset_exit: prepared.derivative_offset_exit.clone(),
2531                structural_monotonicity: true,
2532                penalties: prepared.time_penalties.clone(),
2533                nullspace_dims: prepared.time_nullspace_dims.clone(),
2534                initial_log_lambdas: time_initial_log_lambdas,
2535                initial_beta: Some(Array1::from_elem(time_p, 1e-4)),
2536            };
2537            Ok::<_, String>(LatentBinaryFitRequest {
2538                data: data.values.view(),
2539                spec: LatentBinaryTermSpec {
2540                    age_entry: age_entry.clone(),
2541                    age_exit: age_exit.clone(),
2542                    event_target: event.mapv(|v| if v >= 0.5 { 1 } else { 0 }),
2543                    weights: weights.clone(),
2544                    derivative_guard: exact_derivative_guard,
2545                    time_block,
2546                    unloaded_mass_entry: prepared.unloaded_mass_entry,
2547                    unloaded_mass_exit: prepared.unloaded_mass_exit,
2548                    meanspec: termspec.clone(),
2549                    mean_offset: threshold_offset.clone(),
2550                },
2551                frailty: config.frailty.clone().unwrap_or(FrailtySpec::None),
2552                options: BlockwiseFitOptions::default(),
2553            })
2554        };
2555
2556    let baseline_cfg = if matches!(
2557        survival_mode,
2558        SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull
2559    ) {
2560        baseline_cfg
2561    } else if baseline_cfg.target != SurvivalBaselineTarget::Linear
2562        && survival_mode == SurvivalLikelihoodMode::MarginalSlope
2563    {
2564        optimize_survival_baseline_config_with_gradient(
2565            &baseline_cfg,
2566            "workflow survival marginal-slope baseline",
2567            |candidate| {
2568                let fit =
2569                    fit_survival_marginal_slope_model(build_marginal_slope_request(candidate)?)
2570                        .map_err(|e| format!("survival marginal-slope fit failed: {e}"))?;
2571                let gradient = marginal_slope_baseline_chain_rule_gradient(
2572                    age_entry.view(),
2573                    age_exit.view(),
2574                    candidate,
2575                    &fit.baseline_offset_residuals,
2576                )?
2577                .ok_or_else(|| {
2578                    "workflow survival marginal-slope baseline unexpectedly has no theta gradient"
2579                        .to_string()
2580                })?;
2581                let hessian = marginal_slope_baseline_chain_rule_hessian(
2582                    age_entry.view(),
2583                    age_exit.view(),
2584                    candidate,
2585                    &fit.baseline_offset_residuals,
2586                    &fit.baseline_offset_curvatures,
2587                )?
2588                .ok_or_else(|| {
2589                    "workflow survival marginal-slope baseline unexpectedly has no theta Hessian"
2590                        .to_string()
2591                })?;
2592                Ok((fit.fit.reml_score, gradient, hessian))
2593            },
2594        )?
2595    } else if baseline_cfg.target != SurvivalBaselineTarget::Linear {
2596        optimize_survival_baseline_config(
2597            &baseline_cfg,
2598            "workflow survival baseline",
2599            |candidate| match survival_mode {
2600                SurvivalLikelihoodMode::LocationScale => Ok(fit_survival_location_scale_model(
2601                    build_location_scale_request(candidate)?,
2602                )
2603                .map_err(|e| format!("survival location-scale fit failed: {e}"))?
2604                .fit
2605                .fit
2606                .reml_score),
2607                SurvivalLikelihoodMode::MarginalSlope => unreachable!(
2608                    "marginal-slope baseline profiling uses analytic GM-probit gradient"
2609                ),
2610                SurvivalLikelihoodMode::Latent => Ok(fit_latent_survival_model(
2611                    build_latent_survival_request(candidate)?,
2612                )
2613                .map_err(|e| format!("latent survival fit failed: {e}"))?
2614                .fit
2615                .reml_score),
2616                SurvivalLikelihoodMode::LatentBinary => Ok(fit_latent_binary_model(
2617                    build_latent_binary_request(candidate)?,
2618                )
2619                .map_err(|e| format!("latent binary fit failed: {e}"))?
2620                .fit
2621                .reml_score),
2622                SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull => {
2623                    unreachable!()
2624                }
2625            },
2626        )?
2627    } else {
2628        baseline_cfg
2629    };
2630
2631    let request = match survival_mode {
2632        SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull => {
2633            if config.noise_offset_column.is_some() {
2634                return Err(
2635                    "noise_offset_column is supported only for survival location-scale or marginal-slope"
2636                        .to_string(),
2637                );
2638            }
2639            let weibull_seed = if survival_mode == SurvivalLikelihoodMode::Weibull
2640                && effective_timewiggle.is_none()
2641            {
2642                let scale = config
2643                    .baseline_scale
2644                    .unwrap_or_else(|| positive_survival_time_seed(&age_exit));
2645                let shape = config.baseline_shape.unwrap_or(1.0);
2646                if !scale.is_finite() || scale <= 0.0 || !shape.is_finite() || shape <= 0.0 {
2647                    return Err(
2648                        "weibull survival fit requires finite positive baseline_scale and baseline_shape"
2649                            .to_string(),
2650                    );
2651                }
2652                Some((scale, shape))
2653            } else {
2654                None
2655            };
2656            FitRequest::SurvivalTransformation(SurvivalTransformationFitRequest {
2657                data: data.values.view(),
2658                spec: SurvivalTransformationTermSpec {
2659                    age_entry: age_entry.clone(),
2660                    age_exit: age_exit.clone(),
2661                    event_target: event.mapv(|value| if value >= 0.5 { 1 } else { 0 }),
2662                    weights: weights.clone(),
2663                    covariate_spec: termspec.clone(),
2664                    covariate_offset: threshold_offset.clone(),
2665                    baseline_cfg,
2666                    likelihood_mode: survival_mode,
2667                    time_anchor,
2668                    time_build: time_build.clone(),
2669                    timewiggle: effective_timewiggle.clone(),
2670                    weibull_seed,
2671                    ridge_lambda: config.ridge_lambda,
2672                },
2673            })
2674        }
2675        SurvivalLikelihoodMode::LocationScale => {
2676            FitRequest::SurvivalLocationScale(build_location_scale_request(&baseline_cfg)?)
2677        }
2678        SurvivalLikelihoodMode::MarginalSlope => {
2679            FitRequest::SurvivalMarginalSlope(build_marginal_slope_request(&baseline_cfg)?)
2680        }
2681        SurvivalLikelihoodMode::Latent => {
2682            FitRequest::LatentSurvival(build_latent_survival_request(&baseline_cfg)?)
2683        }
2684        SurvivalLikelihoodMode::LatentBinary => {
2685            FitRequest::LatentBinary(build_latent_binary_request(&baseline_cfg)?)
2686        }
2687    };
2688
2689    Ok(MaterializedModel {
2690        request,
2691        inference_notes,
2692    })
2693}
2694
2695fn materialize_transformation_normal<'a>(
2696    parsed: &ParsedFormula,
2697    data: &'a Dataset,
2698    col_map: &HashMap<String, usize>,
2699    config: &FitConfig,
2700) -> Result<MaterializedModel<'a>, String> {
2701    if parsed.linkspec.is_some() {
2702        return Err("link(...) is not supported for the transformation-normal family".to_string());
2703    }
2704    if parsed.linkwiggle.is_some() {
2705        return Err(
2706            "linkwiggle(...) is not supported for the transformation-normal family".to_string(),
2707        );
2708    }
2709    if config.noise_offset_column.is_some() {
2710        return Err(
2711            "noise_offset_column is not supported for transformation-normal models".to_string(),
2712        );
2713    }
2714    if config.frailty.is_some() {
2715        return Err("frailty is not supported for transformation-normal models".to_string());
2716    }
2717
2718    let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
2719    let y = data.values.column(y_col).to_owned();
2720    let mut inference_notes = Vec::new();
2721
2722    let policy = resolved_resource_policy(config);
2723    let mut covariate_spec =
2724        build_termspec(&parsed.terms, data, col_map, &mut inference_notes, &policy)?;
2725    if config.scale_dimensions {
2726        enable_scale_dimensions(&mut covariate_spec);
2727    }
2728
2729    let weights = resolve_weight_column(data, col_map, config.weight_column.as_deref())?;
2730    let offset = resolve_offset_column(data, col_map, config.offset_column.as_deref())?;
2731
2732    Ok(MaterializedModel {
2733        request: FitRequest::TransformationNormal(TransformationNormalFitRequest {
2734            data: data.values.view(),
2735            response: y,
2736            weights,
2737            offset,
2738            covariate_spec,
2739            config: TransformationNormalConfig::default(),
2740            options: BlockwiseFitOptions::default(),
2741            kappa_options: SpatialLengthScaleOptimizationOptions::default(),
2742            warm_start: None,
2743        }),
2744        inference_notes,
2745    })
2746}
2747
2748fn materialize_location_scale<'a>(
2749    parsed: &ParsedFormula,
2750    data: &'a Dataset,
2751    col_map: &HashMap<String, usize>,
2752    config: &FitConfig,
2753) -> Result<MaterializedModel<'a>, String> {
2754    let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
2755    let y = data.values.column(y_col).to_owned();
2756    let mut inference_notes = Vec::new();
2757
2758    let noise_formula = config
2759        .noise_formula
2760        .as_deref()
2761        .ok_or_else(|| "noise_formula is required for location-scale models".to_string())?;
2762    let noise_parsed = parse_formula(&format!("{} ~ {noise_formula}", parsed.response))?;
2763
2764    let link_choice = parse_link_choice(config.link.as_deref(), config.flexible_link)?;
2765    let family = resolve_family(config.family.as_deref(), link_choice.as_ref(), y.view())?;
2766
2767    let effective_linkwiggle =
2768        effectivelinkwiggle_formulaspec(parsed.linkwiggle.as_ref(), link_choice.as_ref());
2769
2770    let policy = resolved_resource_policy(config);
2771    let mut meanspec = build_termspec(&parsed.terms, data, col_map, &mut inference_notes, &policy)?;
2772    let mut log_sigmaspec = build_termspec(
2773        &noise_parsed.terms,
2774        data,
2775        col_map,
2776        &mut inference_notes,
2777        &policy,
2778    )?;
2779    if config.scale_dimensions {
2780        enable_scale_dimensions(&mut meanspec);
2781        enable_scale_dimensions(&mut log_sigmaspec);
2782    }
2783
2784    let weights = resolve_weight_column(data, col_map, config.weight_column.as_deref())?;
2785    let mean_offset = resolve_offset_column(data, col_map, config.offset_column.as_deref())?;
2786    let noise_offset = resolve_offset_column(data, col_map, config.noise_offset_column.as_deref())?;
2787    let kappa_options = SpatialLengthScaleOptimizationOptions::default();
2788    let options = BlockwiseFitOptions::default();
2789
2790    let wiggle_cfg = effective_linkwiggle.map(|cfg| LinkWiggleConfig {
2791        degree: cfg.degree,
2792        num_internal_knots: cfg.num_internal_knots,
2793        penalty_orders: cfg.penalty_orders,
2794        double_penalty: cfg.double_penalty,
2795    });
2796
2797    if matches!(family, LikelihoodFamily::BinomialLatentCLogLog) {
2798        return Err(
2799            "latent-cloglog-binomial is not implemented for location-scale fitting".to_string(),
2800        );
2801    }
2802
2803    if is_binomial_family(family) {
2804        let link_kind = link_choice
2805            .as_ref()
2806            .map(|c| InverseLink::Standard(c.link))
2807            .unwrap_or(InverseLink::Standard(LinkFunction::Logit));
2808        Ok(MaterializedModel {
2809            request: FitRequest::BinomialLocationScale(BinomialLocationScaleFitRequest {
2810                data: data.values.view(),
2811                spec: BinomialLocationScaleTermSpec {
2812                    y,
2813                    weights,
2814                    link_kind,
2815                    thresholdspec: meanspec,
2816                    log_sigmaspec,
2817                    threshold_offset: mean_offset,
2818                    log_sigma_offset: noise_offset,
2819                },
2820                wiggle: wiggle_cfg,
2821                options,
2822                kappa_options,
2823            }),
2824            inference_notes,
2825        })
2826    } else {
2827        Ok(MaterializedModel {
2828            request: FitRequest::GaussianLocationScale(GaussianLocationScaleFitRequest {
2829                data: data.values.view(),
2830                spec: GaussianLocationScaleTermSpec {
2831                    y,
2832                    weights,
2833                    meanspec,
2834                    log_sigmaspec,
2835                    mean_offset,
2836                    log_sigma_offset: noise_offset,
2837                },
2838                wiggle: wiggle_cfg,
2839                options,
2840                kappa_options,
2841            }),
2842            inference_notes,
2843        })
2844    }
2845}
2846
2847#[cfg(test)]
2848mod tests {
2849    use super::*;
2850    use crate::basis::{DuchonNullspaceOrder, minimum_duchon_power_for_operator_penalties};
2851    use crate::inference::data::load_dataset_projected;
2852    use crate::inference::formula_dsl::{
2853        default_linkwiggle_formulaspec, parse_linkwiggle_formulaspec,
2854    };
2855    use crate::inference::model::{ColumnKindTag, DataSchema, SchemaColumn};
2856    use crate::solver::outer_strategy::{HessianSource, OuterPlan, OuterResult, Solver};
2857    use ndarray::Array2;
2858    use std::fs;
2859    use tempfile::tempdir;
2860
2861    fn load_survival_dataset() -> crate::inference::data::EncodedDataset {
2862        let td = tempdir().expect("tempdir");
2863        let data_path = td.path().join("survival.csv");
2864        fs::write(
2865            &data_path,
2866            "entry,exit,event,x,z\n0.0,1.0,1,0.2,-0.4\n0.3,1.6,0,-0.1,0.6\n",
2867        )
2868        .expect("write survival csv");
2869        load_dataset_projected(
2870            &data_path,
2871            &[
2872                "entry".to_string(),
2873                "exit".to_string(),
2874                "event".to_string(),
2875                "x".to_string(),
2876                "z".to_string(),
2877            ],
2878        )
2879        .expect("load survival dataset")
2880    }
2881
2882    #[test]
2883    fn survival_marginal_slope_materialize_rejects_z_column_in_main_formula() {
2884        let data = load_survival_dataset();
2885        let mut config = FitConfig::default();
2886        config.survival_likelihood = "marginal-slope".to_string();
2887        config.logslope_formula = Some("1".to_string());
2888        config.z_column = Some("z".to_string());
2889
2890        let err = materialize("Surv(entry, exit, event) ~ x + z", &data, &config)
2891            .err()
2892            .expect("main formula should reject z-column reuse");
2893
2894        assert!(err.contains("survival marginal-slope reserves z column 'z'"));
2895        assert!(err.contains("main formula"));
2896    }
2897
2898    #[test]
2899    fn survival_marginal_slope_materialize_rejects_z_column_in_logslope_formula() {
2900        let data = load_survival_dataset();
2901        let mut config = FitConfig::default();
2902        config.survival_likelihood = "marginal-slope".to_string();
2903        config.logslope_formula = Some("1 + z".to_string());
2904        config.z_column = Some("z".to_string());
2905
2906        let err = materialize("Surv(entry, exit, event) ~ x", &data, &config)
2907            .err()
2908            .expect("logslope formula should reject z-column reuse");
2909
2910        assert!(err.contains("survival marginal-slope reserves z column 'z'"));
2911        assert!(err.contains("logslope_formula"));
2912    }
2913
2914    #[test]
2915    fn survival_marginal_slope_materialize_rejects_z_column_when_logslope_defaults_to_main_spec() {
2916        let data = load_survival_dataset();
2917        let mut config = FitConfig::default();
2918        config.survival_likelihood = "marginal-slope".to_string();
2919        config.z_column = Some("z".to_string());
2920
2921        let err = materialize("Surv(entry, exit, event) ~ x + z", &data, &config)
2922            .err()
2923            .expect("defaulted logslope spec should still reject z-column reuse");
2924
2925        assert!(err.contains("survival marginal-slope reserves z column 'z'"));
2926        assert!(err.contains("main formula"));
2927    }
2928
2929    fn workflow_test_dataset() -> Dataset {
2930        Dataset {
2931            headers: vec![
2932                "age_entry".to_string(),
2933                "age_exit".to_string(),
2934                "event".to_string(),
2935                "bmi".to_string(),
2936                "z".to_string(),
2937            ],
2938            values: Array2::from_shape_vec(
2939                (4, 5),
2940                vec![
2941                    40.0, 43.0, 1.0, 22.0, -1.0, 41.0, 46.0, 0.0, 24.0, -0.2, 42.0, 47.0, 1.0,
2942                    27.0, 0.3, 44.0, 49.0, 0.0, 29.0, 1.2,
2943                ],
2944            )
2945            .expect("workflow test data shape"),
2946            schema: DataSchema {
2947                columns: vec![
2948                    SchemaColumn {
2949                        name: "age_entry".to_string(),
2950                        kind: ColumnKindTag::Continuous,
2951                        levels: vec![],
2952                    },
2953                    SchemaColumn {
2954                        name: "age_exit".to_string(),
2955                        kind: ColumnKindTag::Continuous,
2956                        levels: vec![],
2957                    },
2958                    SchemaColumn {
2959                        name: "event".to_string(),
2960                        kind: ColumnKindTag::Binary,
2961                        levels: vec![],
2962                    },
2963                    SchemaColumn {
2964                        name: "bmi".to_string(),
2965                        kind: ColumnKindTag::Continuous,
2966                        levels: vec![],
2967                    },
2968                    SchemaColumn {
2969                        name: "z".to_string(),
2970                        kind: ColumnKindTag::Continuous,
2971                        levels: vec![],
2972                    },
2973                ],
2974            },
2975            column_kinds: vec![
2976                ColumnKindTag::Continuous,
2977                ColumnKindTag::Continuous,
2978                ColumnKindTag::Binary,
2979                ColumnKindTag::Continuous,
2980                ColumnKindTag::Continuous,
2981            ],
2982        }
2983    }
2984
2985    fn workflow_test_outer_result(converged: bool, rho: Array1<f64>) -> OuterResult {
2986        OuterResult {
2987            rho,
2988            final_value: 1.25,
2989            iterations: 7,
2990            final_grad_norm: 0.5,
2991            final_gradient: None,
2992            final_hessian: None,
2993            converged,
2994            plan_used: OuterPlan {
2995                solver: Solver::Bfgs,
2996                hessian_source: HessianSource::BfgsApprox,
2997            },
2998            operator_trust_radius: None,
2999            operator_stop_reason: None,
3000        }
3001    }
3002
/// A `linkwiggle(...)` term in the main formula must populate `link_dev`, while
/// a `linkwiggle(...)` term in `logslope_formula` must populate `score_warp`;
/// each keeps its own degree, knot count and penalty orders, and both routings
/// are reported in the inference notes.
#[test]
fn workflow_survival_marginal_slope_routes_logslope_linkwiggle_into_score_warp_only() {
    let main_formula = "Surv(age_entry, age_exit, event) ~ s(bmi) + linkwiggle(degree=4, internal_knots=9, penalty_order=\"1\")";
    let logslope_formula = "1 + linkwiggle(degree=5, internal_knots=7, penalty_order=\"2,3\")";

    let dataset = workflow_test_dataset();
    let config = FitConfig {
        survival_likelihood: "marginal-slope".to_string(),
        logslope_formula: Some(logslope_formula.to_string()),
        z_column: Some("z".to_string()),
        ..FitConfig::default()
    };

    let MaterializedModel {
        request,
        inference_notes,
    } = materialize(main_formula, &dataset, &config)
        .expect("workflow materialization should succeed");
    let FitRequest::SurvivalMarginalSlope(request) = request else {
        panic!("expected survival marginal-slope request");
    };

    // Main-formula linkwiggle settings land in the link-deviation block.
    let link_dev = request.spec.link_dev.expect("main-formula link-dev");
    // Logslope-formula linkwiggle settings land in the score-warp block.
    let score_warp = request.spec.score_warp.expect("logslope score-warp");

    assert_eq!(link_dev.degree, 4);
    assert_eq!(link_dev.num_internal_knots, 9);
    assert_eq!(link_dev.penalty_order, 1);
    assert_eq!(link_dev.penalty_orders, vec![1]);

    assert_eq!(score_warp.degree, 5);
    assert_eq!(score_warp.num_internal_knots, 7);
    assert_eq!(score_warp.penalty_order, 3);
    assert_eq!(score_warp.penalty_orders, vec![2, 3]);

    let any_note_mentions =
        |needle: &str| inference_notes.iter().any(|note| note.contains(needle));
    assert!(
        any_note_mentions("link-deviation block"),
        "workflow notes should mention main-formula linkwiggle routing"
    );
    assert!(
        any_note_mentions("score-warp block"),
        "workflow notes should mention logslope_formula linkwiggle routing"
    );
}
3052
/// A non-survival formula combined with both a `logslope_formula` and a
/// `z_column` must materialize as a Bernoulli marginal-slope request.
#[test]
fn materialize_routes_bernoulli_marginal_slope_when_logslope_and_z_are_set() {
    let dataset = workflow_test_dataset();
    let config = FitConfig {
        z_column: Some("z".to_string()),
        logslope_formula: Some("1".to_string()),
        ..FitConfig::default()
    };

    let MaterializedModel { request, .. } = materialize("event ~ bmi", &dataset, &config)
        .expect("Bernoulli marginal-slope materialization should succeed");

    assert!(matches!(request, FitRequest::BernoulliMarginalSlope(_)));
}
3068
/// Guards against drift between three sources of `linkwiggle` defaults: the
/// result of parsing an argument-free `linkwiggle()`, the declared formula
/// default spec, and `DeviationBlockConfig::default()`.
#[test]
fn linkwiggle_defaults_are_consistent_across_formula_and_runtime() {
    let from_formula_text = parse_linkwiggle_formulaspec(&Default::default(), "linkwiggle()")
        .expect("default linkwiggle should parse");
    let declared = default_linkwiggle_formulaspec();
    let runtime = DeviationBlockConfig::default();

    // Parsed `linkwiggle()` versus the declared formula defaults.
    assert_eq!(from_formula_text.degree, declared.degree);
    assert_eq!(
        from_formula_text.num_internal_knots,
        declared.num_internal_knots
    );
    assert_eq!(from_formula_text.penalty_orders, declared.penalty_orders);
    assert_eq!(from_formula_text.double_penalty, declared.double_penalty);

    // Runtime block configuration versus the declared formula defaults.
    assert_eq!(runtime.degree, declared.degree);
    assert_eq!(runtime.num_internal_knots, declared.num_internal_knots);
    assert_eq!(runtime.penalty_orders, declared.penalty_orders);
    assert_eq!(runtime.double_penalty, declared.double_penalty);
}
3096
/// Survival marginal-slope models accept an explicit `link(type=probit)` term
/// and reject any other link type with an error mentioning
/// `only link(type=probit)`.
///
/// Both checks report the actual error on failure so a regression is
/// diagnosable from the test output alone.
#[test]
fn survival_marginal_slope_accepts_explicit_probit_link() {
    let data = workflow_test_dataset();
    let config = FitConfig {
        survival_likelihood: "marginal-slope".to_string(),
        logslope_formula: Some("1".to_string()),
        z_column: Some("z".to_string()),
        ..FitConfig::default()
    };

    // Positive case: surface the rejection reason instead of a bare `is_ok()` failure.
    if let Err(err) = materialize(
        "Surv(age_entry, age_exit, event) ~ bmi + link(type=probit)",
        &data,
        &config,
    ) {
        panic!("explicit probit should be accepted: {:?}", err);
    }

    // Negative case: any non-probit link must be refused.
    let Err(err) = materialize(
        "Surv(age_entry, age_exit, event) ~ bmi + link(type=logit)",
        &data,
        &config,
    ) else {
        panic!("non-probit link should be rejected");
    };
    assert!(
        err.contains("only link(type=probit)"),
        "unexpected rejection message: {:?}",
        err
    );
}
3123
/// For a 16-dimensional input with a zero-order nullspace and operator-penalty
/// argument 2, the minimum power returned must satisfy
/// `2 * (1 + power) > dim + 2`.
#[test]
fn high_dimensional_duchon_default_power_is_admissible() {
    let dim = 16;
    let min_power =
        minimum_duchon_power_for_operator_penalties(dim, DuchonNullspaceOrder::Zero, 2);

    // Name both sides of the strict inequality being checked.
    let attained = 2 * (1 + min_power);
    let required_above = dim + 2;
    assert!(attained > required_above);
}
3130
/// Fitting a survival location-scale model that carries a link wiggle must fail
/// once its inverse link is swapped for a SAS link, with an error naming the
/// survival link wiggle and stating the link is not supported.
#[test]
fn survival_location_scale_wiggle_rejects_unsupported_inverse_link() {
    let dataset = workflow_test_dataset();
    let MaterializedModel { request, .. } = materialize(
        "Surv(age_entry, age_exit, event) ~ bmi + linkwiggle(degree=4, internal_knots=3, penalty_order=\"1\")",
        &dataset,
        &FitConfig::default(),
    )
    .expect("workflow materialization should succeed");
    let FitRequest::SurvivalLocationScale(mut request) = request else {
        panic!("expected survival location-scale request");
    };

    // Overwrite the materialized inverse link with a fixed SAS link and turn
    // off inverse-link optimization before fitting.
    let sas_state = state_from_sasspec(SasLinkSpec {
        initial_epsilon: 0.1,
        initial_log_delta: 0.0,
    })
    .expect("valid SAS state");
    request.spec.inverse_link = InverseLink::Sas(sas_state);
    request.optimize_inverse_link = false;

    let Err(message) = fit_survival_location_scale_model(request) else {
        panic!("survival link wiggle should reject unsupported inverse links");
    };

    assert!(message.contains("survival link wiggle"));
    assert!(message.contains("does not support"));
}
3162
/// A non-converged outer result must be turned into an error even when the
/// recovery callback would yield a valid inverse link; the error mentions the
/// lack of convergence and the final objective.
#[test]
fn survival_inverse_link_result_requires_convergence() {
    let stalled = workflow_test_outer_result(false, Array1::from_vec(vec![0.1, -0.2]));

    let outcome = recover_converged_survival_inverse_link(
        stalled,
        "survival inverse-link optimization (SAS, dim=2)",
        |_| Some(InverseLink::Standard(LinkFunction::Logit)),
    );
    let message = outcome.expect_err("non-converged inverse-link search should fail");

    assert!(message.contains("did not converge"));
    assert!(message.contains("final_objective"));
}
3175
/// A converged outer result whose recovery callback yields `None` must be
/// reported as an invalid inverse-link state, and the error must include the
/// offending parameter value.
#[test]
fn survival_inverse_link_result_requires_recoverable_state() {
    let converged = workflow_test_outer_result(true, Array1::from_vec(vec![9.0, 8.0]));

    let outcome = recover_converged_survival_inverse_link(
        converged,
        "survival inverse-link optimization (mixture, dim=2)",
        |_| None,
    );
    let message = outcome.expect_err("unrecoverable inverse-link state should fail");

    assert!(message.contains("produced an invalid inverse-link state"));
    assert!(message.contains("9.0"));
}
3188}