1use gam_solve::estimate::UnifiedFitResult;
19use crate::bms::deviation_runtime::AnchorComponentTag;
20use crate::bms::{
21 DeviationRuntime, LatentMeasureKind, LatentZConditionalCalibration, LatentZRankIntCalibration,
22};
23use crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL;
24use crate::scale_design::ScaleDeviationTransform;
25use crate::survival::construction::{
26 SavedSurvivalTimeBasis, SurvivalBaselineConfig, survival_baseline_targetname,
27};
28use crate::survival::location_scale::{
29 ResidualDistribution, residual_distribution_from_inverse_link,
30};
31use crate::transformation_normal::TransformationNormalFamily;
32use crate::inference::model::{
33 FittedFamily, FittedModelPayload, MODEL_PAYLOAD_VERSION, ModelKind, SavedAnchorComponent,
34 SavedAnchorKind, SavedCompiledFlexBlock, SavedLatentZNormalization, SavedResidualCascade,
35 SavedSplineScan, TransformationScoreCalibration,
36};
37use gam_terms::smooth::TermCollectionSpec;
38use gam_problem::types::{
39 InverseLink, LikelihoodSpec, ResponseFamily, StandardLink, inverse_link_to_binomial_spec,
40};
41use gam_data::DataSchema;
42use ndarray::Array2;
43
/// Model-class label passed to `FittedModelPayload::new` by
/// [`assemble_bernoulli_marginal_slope_payload`].
const FAMILY_BERNOULLI_MARGINAL_SLOPE: &str = "bernoulli-marginal-slope";

/// Model-class label passed to `FittedModelPayload::new` by
/// [`assemble_transformation_normal_payload`].
const FAMILY_TRANSFORMATION_NORMAL: &str = "transformation-normal";
49
50pub fn serialize_anchored_deviation_runtime(runtime: &DeviationRuntime) -> SavedCompiledFlexBlock {
57 let mut anchor_correction: Option<Vec<Vec<f64>>> = None;
58 let mut anchor_components: Vec<SavedAnchorComponent> = Vec::new();
59 if let Some(installed) = runtime.installed_flex_block() {
60 anchor_correction = Some(
61 installed
62 .anchor_correction
63 .rows()
64 .into_iter()
65 .map(|row| row.to_vec())
66 .collect::<Vec<Vec<f64>>>(),
67 );
68 for component in &installed.anchor_components {
69 anchor_components.push(SavedAnchorComponent {
70 kind: match component {
71 AnchorComponentTag::Parametric { block, ncols } => {
72 SavedAnchorKind::Parametric {
73 block: *block,
74 ncols: *ncols,
75 }
76 }
77 AnchorComponentTag::FlexEvaluation { ncols } => {
78 SavedAnchorKind::FlexEvaluation { ncols: *ncols }
79 }
80 },
81 });
82 }
83 }
84 SavedCompiledFlexBlock {
85 kernel: ANCHORED_DEVIATION_KERNEL.to_string(),
86 breakpoints: runtime.breakpoints().to_vec(),
87 basis_dim: runtime.basis_dim(),
88 span_c0: runtime
89 .span_c0()
90 .rows()
91 .into_iter()
92 .map(|row| row.to_vec())
93 .collect(),
94 span_c1: runtime
95 .span_c1()
96 .rows()
97 .into_iter()
98 .map(|row| row.to_vec())
99 .collect(),
100 span_c2: runtime
101 .span_c2()
102 .rows()
103 .into_iter()
104 .map(|row| row.to_vec())
105 .collect(),
106 span_c3: runtime
107 .span_c3()
108 .rows()
109 .into_iter()
110 .map(|row| row.to_vec())
111 .collect(),
112 anchor_correction,
113 anchor_components,
114 }
115}
116
117pub struct SavedModelSourceMetadata {
125 pub training_headers: Vec<String>,
126 pub training_feature_ranges: Option<Vec<(f64, f64)>>,
127 pub offset_column: Option<String>,
128 pub noise_offset_column: Option<String>,
129}
130
131impl SavedModelSourceMetadata {
132 fn apply_to(self, payload: &mut FittedModelPayload) {
133 match self.training_feature_ranges {
134 Some(ranges) => payload.set_training_feature_metadata(self.training_headers, ranges),
135 None => payload.training_headers = Some(self.training_headers),
136 }
137 payload.offset_column = self.offset_column;
138 payload.noise_offset_column = self.noise_offset_column;
139 }
140}
141
142pub struct BernoulliMarginalSlopeInputs<'a> {
149 pub formula: String,
150 pub data_schema: DataSchema,
151 pub logslope_formula: String,
152 pub z_column: String,
153 pub resolved_marginalspec: TermCollectionSpec,
154 pub resolved_logslopespec: TermCollectionSpec,
155 pub fit_result: UnifiedFitResult,
156 pub p_marginal: usize,
169 pub baseline_marginal: f64,
170 pub baseline_logslope: f64,
171 pub latent_z_normalization: SavedLatentZNormalization,
172 pub latent_measure: LatentMeasureKind,
173 pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
174 pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
175 pub score_warp_runtime: Option<&'a DeviationRuntime>,
176 pub link_dev_runtime: Option<&'a DeviationRuntime>,
177 pub base_link: InverseLink,
178 pub frailty: crate::survival::lognormal_kernel::FrailtySpec,
179}
180
181fn truncate_marginal_slope_influence_absorber(
214 fit_result: UnifiedFitResult,
215 p_marginal: usize,
216) -> Result<UnifiedFitResult, String> {
217 let Some(block0) = fit_result.blocks.first() else {
218 return Err("marginal-slope fit result has no coefficient blocks".to_string());
219 };
220 let widened_len = block0.beta.len();
221 if widened_len <= p_marginal {
222 return Ok(fit_result);
224 }
225 let p_influence = widened_len - p_marginal;
226
227 let UnifiedFitResult {
228 mut blocks,
229 log_lambdas,
230 lambdas,
231 likelihood_family,
232 likelihood_scale,
233 log_likelihood_normalization,
234 log_likelihood,
235 deviance,
236 reml_score,
237 stable_penalty_term,
238 penalized_objective,
239 used_device,
240 outer_iterations,
241 outer_converged,
242 outer_gradient_norm,
243 standard_deviation,
244 covariance_conditional,
245 covariance_corrected,
246 inference,
247 fitted_link,
248 geometry: _,
249 mut block_states,
250 beta: _,
251 pirls_status,
252 max_abs_eta,
253 constraint_kkt,
254 artifacts,
255 inner_cycles,
256 outer_cost_evals: _,
257 } = fit_result;
258
259 blocks[0].beta = blocks[0].beta.slice(ndarray::s![..p_marginal]).to_owned();
262 if let Some(state0) = block_states.first_mut() {
263 state0.beta = state0.beta.slice(ndarray::s![..p_marginal]).to_owned();
264 }
265
266 let drop_gamma_block = |cov: Option<Array2<f64>>| -> Option<Array2<f64>> {
269 cov.map(|cov| {
270 let total = cov.nrows();
271 let kept: Vec<usize> = (0..p_marginal)
272 .chain((p_marginal + p_influence)..total)
273 .collect();
274 let mut out = Array2::<f64>::zeros((kept.len(), kept.len()));
275 for (ri, &r) in kept.iter().enumerate() {
276 for (ci, &c) in kept.iter().enumerate() {
277 out[[ri, ci]] = cov[[r, c]];
278 }
279 }
280 out
281 })
282 };
283 let covariance_conditional = drop_gamma_block(covariance_conditional);
284 let covariance_corrected = drop_gamma_block(covariance_corrected);
285
286 UnifiedFitResult::try_from_parts(gam_solve::estimate::UnifiedFitResultParts {
287 blocks,
288 log_lambdas,
289 lambdas,
290 likelihood_family,
291 likelihood_scale,
292 log_likelihood_normalization,
293 log_likelihood,
294 deviance,
295 reml_score,
296 stable_penalty_term,
297 penalized_objective,
298 used_device,
302 outer_iterations,
303 outer_converged,
304 outer_gradient_norm,
305 standard_deviation,
306 covariance_conditional,
307 covariance_corrected,
308 inference,
309 fitted_link,
310 geometry: None,
312 block_states,
313 pirls_status,
314 max_abs_eta,
315 constraint_kkt,
316 artifacts,
317 inner_cycles,
318 })
319 .map_err(|e| {
320 format!("marginal-slope influence-absorber truncation produced an invalid fit result: {e}")
321 })
322}
323
324pub fn assemble_spline_scan_payload(
330 formula: String,
331 feature_column: String,
332 fit: &gam_solve::spline_scan::SplineScanFit,
333 data_schema: DataSchema,
334 training_headers: Vec<String>,
335 training_feature_ranges: Vec<(f64, f64)>,
336) -> FittedModelPayload {
337 let mut payload = FittedModelPayload::new(
338 MODEL_PAYLOAD_VERSION,
339 formula,
340 ModelKind::Standard,
341 FittedFamily::Standard {
342 likelihood: LikelihoodSpec::gaussian_identity(),
343 link: None,
344 latent_cloglog_state: None,
345 mixture_state: None,
346 sas_state: None,
347 },
348 "gaussian".to_string(),
349 );
350 payload.spline_scan = Some(SavedSplineScan {
351 feature_column,
352 state: fit.to_state(),
353 });
354 payload.data_schema = Some(data_schema);
355 payload.set_training_feature_metadata(training_headers, training_feature_ranges);
356 payload
357}
358
359pub fn assemble_residual_cascade_payload(
365 formula: String,
366 feature_columns: Vec<String>,
367 fit: &gam_solve::residual_cascade::ResidualCascadeFit,
368 data_schema: DataSchema,
369 training_headers: Vec<String>,
370 training_feature_ranges: Vec<(f64, f64)>,
371) -> Result<FittedModelPayload, String> {
372 let mut payload = FittedModelPayload::new(
373 MODEL_PAYLOAD_VERSION,
374 formula,
375 ModelKind::Standard,
376 FittedFamily::Standard {
377 likelihood: gam_problem::types::LikelihoodSpec::gaussian_identity(),
378 link: None,
379 latent_cloglog_state: None,
380 mixture_state: None,
381 sas_state: None,
382 },
383 "gaussian".to_string(),
384 );
385 payload.residual_cascade = Some(SavedResidualCascade {
386 feature_columns,
387 state: fit.to_state().map_err(|e| {
388 format!("residual-cascade to_state failed during payload assembly: {e}")
389 })?,
390 });
391 payload.data_schema = Some(data_schema);
392 payload.set_training_feature_metadata(training_headers, training_feature_ranges);
393 Ok(payload)
394}
395
396pub fn assemble_bernoulli_marginal_slope_payload(
404 inputs: BernoulliMarginalSlopeInputs<'_>,
405 source: SavedModelSourceMetadata,
406) -> Result<FittedModelPayload, String> {
407 let BernoulliMarginalSlopeInputs {
408 formula,
409 data_schema,
410 logslope_formula,
411 z_column,
412 resolved_marginalspec,
413 resolved_logslopespec,
414 fit_result,
415 p_marginal,
416 baseline_marginal,
417 baseline_logslope,
418 latent_z_normalization,
419 latent_measure,
420 latent_z_rank_int_calibration,
421 latent_z_conditional_calibration,
422 score_warp_runtime,
423 link_dev_runtime,
424 base_link,
425 frailty,
426 } = inputs;
427
428 let fit_result = truncate_marginal_slope_influence_absorber(fit_result, p_marginal)?;
432
433 let marginal_likelihood_spec =
434 inverse_link_to_binomial_spec(&base_link).map_err(|e| e.to_string())?;
435
436 let mut payload = FittedModelPayload::new(
437 MODEL_PAYLOAD_VERSION,
438 formula,
439 ModelKind::MarginalSlope,
440 FittedFamily::MarginalSlope {
441 likelihood: marginal_likelihood_spec,
442 base_link: base_link.clone(),
443 frailty,
444 },
445 FAMILY_BERNOULLI_MARGINAL_SLOPE.to_string(),
446 );
447 payload.unified = Some(fit_result.clone());
448 payload.fit_result = Some(fit_result);
449 payload.data_schema = Some(data_schema);
450 payload.formula_logslope = Some(logslope_formula.clone());
451 payload.z_column = Some(z_column.clone());
452 payload.formula_logslopes = Some(vec![logslope_formula]);
453 payload.z_columns = Some(vec![z_column]);
454 payload.latent_z_normalization = Some(latent_z_normalization);
455 payload.latent_measure = Some(latent_measure);
456 payload.latent_z_rank_int_calibration = latent_z_rank_int_calibration;
457 payload.latent_z_conditional_calibration = latent_z_conditional_calibration;
458 payload.marginal_baseline = Some(baseline_marginal);
459 payload.logslope_baseline = Some(baseline_logslope);
460 payload.logslope_baselines = Some(vec![baseline_logslope]);
461 payload.link = Some(base_link);
462 payload.resolved_termspec = Some(resolved_marginalspec);
463 payload.resolved_termspec_logslopes = Some(vec![resolved_logslopespec.clone()]);
464 payload.resolved_termspec_logslope = Some(resolved_logslopespec);
465 payload.score_warp_runtime = score_warp_runtime.map(serialize_anchored_deviation_runtime);
466 payload.link_deviation_runtime = link_dev_runtime.map(serialize_anchored_deviation_runtime);
467 source.apply_to(&mut payload);
468 Ok(payload)
469}
470
471pub struct TransformationNormalInputs<'a> {
478 pub formula: String,
479 pub data_schema: DataSchema,
480 pub resolved_covariate_spec: TermCollectionSpec,
481 pub fit_result: UnifiedFitResult,
482 pub family: &'a TransformationNormalFamily,
483 pub score_calibration: TransformationScoreCalibration,
484}
485
486pub fn assemble_transformation_normal_payload(
492 inputs: TransformationNormalInputs<'_>,
493 source: SavedModelSourceMetadata,
494) -> FittedModelPayload {
495 let TransformationNormalInputs {
496 formula,
497 data_schema,
498 resolved_covariate_spec,
499 fit_result,
500 family,
501 score_calibration,
502 } = inputs;
503
504 let mut payload = FittedModelPayload::new(
505 MODEL_PAYLOAD_VERSION,
506 formula,
507 ModelKind::TransformationNormal,
508 FittedFamily::TransformationNormal {
509 likelihood: LikelihoodSpec::new(
510 ResponseFamily::Gaussian,
511 InverseLink::Standard(StandardLink::Identity),
512 ),
513 },
514 FAMILY_TRANSFORMATION_NORMAL.to_string(),
515 );
516 payload.unified = Some(fit_result.clone());
517 payload.fit_result = Some(fit_result);
518 payload.data_schema = Some(data_schema);
519 payload.resolved_termspec = Some(resolved_covariate_spec);
520 payload.transformation_response_knots = Some(family.response_knots().to_vec());
521 payload.transformation_response_transform = Some(
522 family
523 .response_transform()
524 .rows()
525 .into_iter()
526 .map(|row| row.to_vec())
527 .collect(),
528 );
529 payload.transformation_response_degree = Some(family.response_degree());
530 payload.transformation_response_median = Some(family.response_median());
531 payload.transformation_score_calibration = Some(score_calibration);
532 source.apply_to(&mut payload);
533 payload
534}
535
536pub enum LocationScaleResponse<'a> {
542 Gaussian {
545 response_scale: f64,
546 base_link: Option<InverseLink>,
547 },
548 Binomial {
550 link: InverseLink,
551 noise_transform: &'a ScaleDeviationTransform,
552 },
553 Dispersion {
559 likelihood: LikelihoodSpec,
560 base_link: InverseLink,
561 family_tag: &'static str,
562 },
563}
564
/// Link-wiggle basis and coefficients saved with a location-scale fit.
pub struct LocationScaleWiggle {
    /// Knot vector, saved as `linkwiggle_knots`.
    pub knots: Vec<f64>,
    /// Basis degree, saved as `linkwiggle_degree`.
    pub degree: usize,
    /// Wiggle coefficients, saved as `beta_link_wiggle`.
    pub beta_link_wiggle: Vec<f64>,
}
574
575pub struct LocationScaleInputs {
579 pub formula: String,
580 pub data_schema: DataSchema,
581 pub noise_formula: String,
582 pub resolved_termspec: TermCollectionSpec,
583 pub resolved_termspec_noise: TermCollectionSpec,
584 pub fit_result: UnifiedFitResult,
585 pub beta_noise: Option<Vec<f64>>,
586 pub wiggle: Option<LocationScaleWiggle>,
587}
588
589pub fn assemble_location_scale_payload(
594 inputs: LocationScaleInputs,
595 response: LocationScaleResponse<'_>,
596 source: SavedModelSourceMetadata,
597) -> Result<FittedModelPayload, String> {
598 let (family_tag, likelihood, base_link, link, response_scale, noise_transform) = match response
599 {
600 LocationScaleResponse::Gaussian {
601 response_scale,
602 base_link,
603 } => (
604 "gaussian-location-scale".to_string(),
605 LikelihoodSpec::gaussian_identity(),
606 None,
610 Some(base_link.unwrap_or(InverseLink::Standard(StandardLink::Identity))),
611 Some(response_scale),
612 None,
613 ),
614 LocationScaleResponse::Binomial {
615 link,
616 noise_transform,
617 } => {
618 let likelihood = inverse_link_to_binomial_spec(&link).map_err(|e| {
619 format!("failed to resolve LikelihoodSpec for binomial location-scale link {link:?}: {e}")
620 })?;
621 (
622 "binomial-location-scale".to_string(),
623 likelihood,
624 Some(link.clone()),
625 Some(link),
626 None,
627 Some(noise_transform),
628 )
629 }
630 LocationScaleResponse::Dispersion {
631 likelihood,
632 base_link,
633 family_tag,
634 } => (
635 family_tag.to_string(),
636 likelihood,
637 Some(base_link.clone()),
638 Some(base_link),
639 None,
640 None,
641 ),
642 };
643
644 let mut payload = FittedModelPayload::new(
645 MODEL_PAYLOAD_VERSION,
646 inputs.formula,
647 ModelKind::LocationScale,
648 FittedFamily::LocationScale {
649 likelihood,
650 base_link,
651 },
652 family_tag,
653 );
654 payload.unified = Some(inputs.fit_result.clone());
655 payload.fit_result = Some(inputs.fit_result);
656 payload.data_schema = Some(inputs.data_schema);
657 payload.link = link;
658 payload.formula_noise = Some(inputs.noise_formula);
659 payload.beta_noise = inputs.beta_noise;
660 payload.gaussian_response_scale = response_scale;
661 if let Some(transform) = noise_transform {
662 payload.noise_projection = Some(
663 transform
664 .projection_coef
665 .rows()
666 .into_iter()
667 .map(|row| row.to_vec())
668 .collect(),
669 );
670 payload.noise_center = Some(transform.weighted_column_mean.to_vec());
671 payload.noise_scale = Some(transform.rescale.to_vec());
672 payload.noise_non_intercept_start = Some(transform.non_intercept_start);
673 payload.noise_projection_ridge_alpha = Some(transform.projection_ridge_alpha);
674 }
675 payload.resolved_termspec = Some(inputs.resolved_termspec);
676 payload.resolved_termspec_noise = Some(inputs.resolved_termspec_noise);
677 if let Some(wiggle) = inputs.wiggle {
678 payload.linkwiggle_knots = Some(wiggle.knots);
679 payload.linkwiggle_degree = Some(wiggle.degree);
680 payload.beta_link_wiggle = Some(wiggle.beta_link_wiggle);
681 }
682 source.apply_to(&mut payload);
683 Ok(payload)
684}
685
686pub struct SurvivalMarginalSlopeInputs<'a> {
691 pub formula: String,
692 pub data_schema: DataSchema,
693 pub fit_result: UnifiedFitResult,
694 pub frailty: crate::survival::lognormal_kernel::FrailtySpec,
695 pub survival_entry: Option<String>,
696 pub survival_exit: String,
697 pub survival_event: String,
698 pub survivalspec: String,
699 pub baseline_cfg: SurvivalBaselineConfig,
700 pub time_basis: SavedSurvivalTimeBasis,
701 pub ridge_lambda: f64,
702 pub survival_likelihood_label: String,
703 pub resolved_marginalspec: TermCollectionSpec,
704 pub resolved_logslopespec: TermCollectionSpec,
705 pub logslope_formula: String,
706 pub z_column: String,
707 pub latent_z_normalization: SavedLatentZNormalization,
708 pub baseline_logslope: f64,
709 pub score_warp_runtime: Option<&'a DeviationRuntime>,
710 pub link_dev_runtime: Option<&'a DeviationRuntime>,
711 pub influence_absorber_width: Option<usize>,
716}
717
718fn new_royston_parmar_survival_payload(
726 formula: String,
727 fit_result: UnifiedFitResult,
728 data_schema: DataSchema,
729 survival_likelihood_label: &str,
730 survival_distribution: Option<ResidualDistribution>,
731 frailty: crate::survival::lognormal_kernel::FrailtySpec,
732) -> FittedModelPayload {
733 let mut payload = FittedModelPayload::new(
734 MODEL_PAYLOAD_VERSION,
735 formula,
736 ModelKind::Survival,
737 FittedFamily::Survival {
738 likelihood: LikelihoodSpec::new(
739 ResponseFamily::RoystonParmar,
740 InverseLink::Standard(StandardLink::Identity),
741 ),
742 survival_likelihood: Some(survival_likelihood_label.to_string()),
743 survival_distribution,
744 frailty,
745 },
746 ResponseFamily::RoystonParmar.name().to_string(),
747 );
748 payload.unified = Some(fit_result.clone());
749 payload.fit_result = Some(fit_result);
750 payload.data_schema = Some(data_schema);
751 payload
752}
753
754pub fn assemble_survival_marginal_slope_payload(
757 inputs: SurvivalMarginalSlopeInputs<'_>,
758 source: SavedModelSourceMetadata,
759) -> FittedModelPayload {
760 let mut payload = new_royston_parmar_survival_payload(
761 inputs.formula,
762 inputs.fit_result,
763 inputs.data_schema,
764 &inputs.survival_likelihood_label,
765 Some(ResidualDistribution::Gaussian),
766 inputs.frailty,
767 );
768 payload.survival_entry = inputs.survival_entry;
769 payload.survival_exit = Some(inputs.survival_exit);
770 payload.survival_event = Some(inputs.survival_event);
771 payload.survivalspec = Some(inputs.survivalspec);
772 payload.survival_baseline_target =
773 Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
774 payload.survival_baseline_scale = inputs.baseline_cfg.scale;
775 payload.survival_baseline_shape = inputs.baseline_cfg.shape;
776 payload.survival_baseline_rate = inputs.baseline_cfg.rate;
777 payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
778 payload.apply_survival_time_basis(&inputs.time_basis);
779 payload.survivalridge_lambda = Some(inputs.ridge_lambda);
780 payload.survival_likelihood = Some(inputs.survival_likelihood_label);
781 payload.survival_distribution = Some(ResidualDistribution::Gaussian);
782 payload.link = Some(InverseLink::Standard(StandardLink::Probit));
783 payload.resolved_termspec = Some(inputs.resolved_marginalspec);
784 payload.resolved_termspec_logslopes = Some(vec![inputs.resolved_logslopespec.clone()]);
785 payload.resolved_termspec_logslope = Some(inputs.resolved_logslopespec);
786 payload.formula_logslope = Some(inputs.logslope_formula.clone());
787 payload.formula_logslopes = Some(vec![inputs.logslope_formula]);
788 payload.z_column = Some(inputs.z_column.clone());
789 payload.z_columns = Some(vec![inputs.z_column]);
790 payload.latent_z_normalization = Some(inputs.latent_z_normalization);
791 payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
792 payload.logslope_baseline = Some(inputs.baseline_logslope);
793 payload.logslope_baselines = Some(vec![inputs.baseline_logslope]);
794 payload.score_warp_runtime = inputs
795 .score_warp_runtime
796 .map(serialize_anchored_deviation_runtime);
797 payload.link_deviation_runtime = inputs
798 .link_dev_runtime
799 .map(serialize_anchored_deviation_runtime);
800 payload.influence_absorber_width = inputs.influence_absorber_width;
801 source.apply_to(&mut payload);
802 payload
803}
804
/// Coefficients of a baseline time wiggle, in either a shared or a per-cause layout.
pub enum SurvivalTimewiggleBeta {
    /// One coefficient vector.
    Single(Vec<f64>),
    /// One coefficient vector per cause.
    ByCause(Vec<Vec<f64>>),
}
811
812pub struct SurvivalTimewiggle {
814 pub degree: usize,
815 pub knots: Vec<f64>,
816 pub penalty_orders: Option<Vec<usize>>,
817 pub double_penalty: Option<bool>,
818 pub beta: SurvivalTimewiggleBeta,
819}
820
821pub struct SurvivalTransformationInputs {
824 pub formula: String,
825 pub data_schema: DataSchema,
826 pub fit_result: UnifiedFitResult,
827 pub survival_entry: Option<String>,
828 pub survival_exit: String,
829 pub survival_event: String,
830 pub survivalspec: String,
831 pub cause_count: Option<usize>,
834 pub baseline_cfg: SurvivalBaselineConfig,
835 pub time_basis: SavedSurvivalTimeBasis,
836 pub ridge_lambda: f64,
837 pub survival_likelihood_label: String,
838 pub resolved_termspec: TermCollectionSpec,
839 pub survival_beta_time: Option<Vec<f64>>,
841 pub timewiggle: Option<SurvivalTimewiggle>,
842}
843
844pub fn assemble_survival_transformation_payload(
847 inputs: SurvivalTransformationInputs,
848 source: SavedModelSourceMetadata,
849) -> FittedModelPayload {
850 let mut payload = new_royston_parmar_survival_payload(
851 inputs.formula,
852 inputs.fit_result,
853 inputs.data_schema,
854 &inputs.survival_likelihood_label,
855 None,
856 crate::survival::lognormal_kernel::FrailtySpec::None,
857 );
858 payload.survival_entry = inputs.survival_entry;
859 payload.survival_exit = Some(inputs.survival_exit);
860 payload.survival_event = Some(inputs.survival_event);
861 payload.survivalspec = Some(inputs.survivalspec);
862 if let Some(cause_count) = inputs.cause_count {
863 payload.survival_cause_count = Some(cause_count);
864 payload.survival_endpoint_names = Some(
865 (1..=cause_count)
866 .map(|idx| format!("cause_{idx}"))
867 .collect(),
868 );
869 }
870 payload.survival_baseline_target =
871 Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
872 payload.survival_baseline_scale = inputs.baseline_cfg.scale;
873 payload.survival_baseline_shape = inputs.baseline_cfg.shape;
874 payload.survival_baseline_rate = inputs.baseline_cfg.rate;
875 payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
876 payload.apply_survival_time_basis(&inputs.time_basis);
877 if let Some(timewiggle) = inputs.timewiggle {
878 payload.baseline_timewiggle_degree = Some(timewiggle.degree);
879 payload.baseline_timewiggle_knots = Some(timewiggle.knots);
880 payload.baseline_timewiggle_penalty_orders = timewiggle.penalty_orders;
881 payload.baseline_timewiggle_double_penalty = timewiggle.double_penalty;
882 match timewiggle.beta {
883 SurvivalTimewiggleBeta::Single(beta) => {
884 payload.beta_baseline_timewiggle = Some(beta);
885 }
886 SurvivalTimewiggleBeta::ByCause(by_cause) => {
887 payload.beta_baseline_timewiggle_by_cause = Some(by_cause);
888 }
889 }
890 }
891 payload.survivalridge_lambda = Some(inputs.ridge_lambda);
892 payload.survival_likelihood = Some(inputs.survival_likelihood_label);
893 payload.survival_beta_time = inputs.survival_beta_time;
894 payload.resolved_termspec = Some(inputs.resolved_termspec);
895 source.apply_to(&mut payload);
896 payload
897}
898
899pub struct SurvivalLocationScaleInputs<'a> {
904 pub formula: String,
905 pub data_schema: DataSchema,
906 pub fit_result: UnifiedFitResult,
909 pub fitted_inverse_link: InverseLink,
910 pub linkwiggle_degree: Option<usize>,
913 pub linkwiggle_knots: Option<Vec<f64>>,
914 pub beta_link_wiggle: Option<Vec<f64>>,
915 pub baseline_timewiggle: Option<SurvivalTimewiggle>,
916 pub survival_entry: Option<String>,
917 pub survival_exit: String,
918 pub survival_event: String,
919 pub survivalspec: String,
920 pub baseline_cfg: SurvivalBaselineConfig,
921 pub time_basis: SavedSurvivalTimeBasis,
922 pub ridge_lambda: f64,
923 pub survival_likelihood_label: String,
924 pub formula_noise: Option<String>,
925 pub survival_beta_time: Vec<f64>,
926 pub survival_beta_threshold: Vec<f64>,
927 pub survival_beta_log_sigma: Vec<f64>,
928 pub noise_transform: &'a ScaleDeviationTransform,
929 pub resolved_thresholdspec: TermCollectionSpec,
930 pub resolved_log_sigmaspec: TermCollectionSpec,
931}
932
933pub fn assemble_survival_location_scale_payload(
936 inputs: SurvivalLocationScaleInputs<'_>,
937 source: SavedModelSourceMetadata,
938) -> FittedModelPayload {
939 let survival_distribution =
940 residual_distribution_from_inverse_link(&inputs.fitted_inverse_link);
941 let mut payload = new_royston_parmar_survival_payload(
942 inputs.formula,
943 inputs.fit_result,
944 inputs.data_schema,
945 &inputs.survival_likelihood_label,
946 survival_distribution,
947 crate::survival::lognormal_kernel::FrailtySpec::None,
948 );
949 payload.link = Some(inputs.fitted_inverse_link);
950 payload.linkwiggle_degree = inputs.linkwiggle_degree;
951 payload.linkwiggle_knots = inputs.linkwiggle_knots;
952 payload.beta_link_wiggle = inputs.beta_link_wiggle;
953 if let Some(timewiggle) = inputs.baseline_timewiggle {
954 payload.baseline_timewiggle_degree = Some(timewiggle.degree);
955 payload.baseline_timewiggle_knots = Some(timewiggle.knots);
956 payload.baseline_timewiggle_penalty_orders = timewiggle.penalty_orders;
957 payload.baseline_timewiggle_double_penalty = timewiggle.double_penalty;
958 if let SurvivalTimewiggleBeta::Single(beta) = timewiggle.beta {
959 payload.beta_baseline_timewiggle = Some(beta);
960 }
961 }
962 payload.survival_entry = inputs.survival_entry;
963 payload.survival_exit = Some(inputs.survival_exit);
964 payload.survival_event = Some(inputs.survival_event);
965 payload.survivalspec = Some(inputs.survivalspec);
966 payload.survival_baseline_target =
967 Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
968 payload.survival_baseline_scale = inputs.baseline_cfg.scale;
969 payload.survival_baseline_shape = inputs.baseline_cfg.shape;
970 payload.survival_baseline_rate = inputs.baseline_cfg.rate;
971 payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
972 payload.apply_survival_time_basis(&inputs.time_basis);
973 payload.survivalridge_lambda = Some(inputs.ridge_lambda);
974 payload.survival_likelihood = Some(inputs.survival_likelihood_label);
975 payload.formula_noise = inputs.formula_noise;
976 payload.survival_beta_time = Some(inputs.survival_beta_time);
977 payload.survival_beta_threshold = Some(inputs.survival_beta_threshold);
978 payload.survival_beta_log_sigma = Some(inputs.survival_beta_log_sigma);
979 payload.survival_noise_projection = Some(
980 inputs
981 .noise_transform
982 .projection_coef
983 .rows()
984 .into_iter()
985 .map(|row| row.to_vec())
986 .collect(),
987 );
988 payload.survival_noise_center = Some(inputs.noise_transform.weighted_column_mean.to_vec());
989 payload.survival_noise_scale = Some(inputs.noise_transform.rescale.to_vec());
990 payload.survival_noise_non_intercept_start = Some(inputs.noise_transform.non_intercept_start);
991 payload.survival_noise_projection_ridge_alpha =
992 Some(inputs.noise_transform.projection_ridge_alpha);
993 payload.survival_distribution = survival_distribution;
994 payload.resolved_termspec = Some(inputs.resolved_thresholdspec);
995 payload.resolved_termspec_noise = Some(inputs.resolved_log_sigmaspec);
996 source.apply_to(&mut payload);
997 payload
998}
999
1000pub struct LatentWindowInputs {
1004 pub formula: String,
1005 pub data_schema: DataSchema,
1006 pub fit_result: UnifiedFitResult,
1007 pub family: FittedFamily,
1008 pub model_class_label: String,
1009 pub likelihood_label: String,
1010 pub survival_entry: Option<String>,
1011 pub survival_exit: String,
1012 pub survival_event: String,
1013 pub baseline_cfg: SurvivalBaselineConfig,
1014 pub time_basis: SavedSurvivalTimeBasis,
1015 pub ridge_lambda: f64,
1016 pub beta_time: Vec<f64>,
1017 pub resolved_termspec: TermCollectionSpec,
1018}
1019
1020pub fn assemble_latent_window_payload(
1022 inputs: LatentWindowInputs,
1023 source: SavedModelSourceMetadata,
1024) -> FittedModelPayload {
1025 let mut payload = FittedModelPayload::new(
1026 MODEL_PAYLOAD_VERSION,
1027 inputs.formula,
1028 ModelKind::Survival,
1029 inputs.family,
1030 inputs.model_class_label,
1031 );
1032 payload.unified = Some(inputs.fit_result.clone());
1033 payload.fit_result = Some(inputs.fit_result);
1034 payload.data_schema = Some(inputs.data_schema);
1035 payload.survival_entry = inputs.survival_entry;
1036 payload.survival_exit = Some(inputs.survival_exit);
1037 payload.survival_event = Some(inputs.survival_event);
1038 payload.survivalspec = Some("net".to_string());
1039 payload.survival_baseline_target =
1040 Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
1041 payload.survival_baseline_scale = inputs.baseline_cfg.scale;
1042 payload.survival_baseline_shape = inputs.baseline_cfg.shape;
1043 payload.survival_baseline_rate = inputs.baseline_cfg.rate;
1044 payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
1045 payload.apply_survival_time_basis(&inputs.time_basis);
1046 payload.survival_likelihood = Some(inputs.likelihood_label);
1047 payload.survival_beta_time = Some(inputs.beta_time);
1048 payload.survivalridge_lambda = Some(inputs.ridge_lambda);
1049 payload.resolved_termspec = Some(inputs.resolved_termspec);
1050 source.apply_to(&mut payload);
1051 payload
1052}