1use gam_solve::estimate::UnifiedFitResult;
19use crate::bms::deviation_runtime::AnchorComponentTag;
20use crate::bms::{
21 DeviationRuntime, LatentMeasureKind, LatentZConditionalCalibration, LatentZRankIntCalibration,
22};
23use crate::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL;
24use crate::scale_design::ScaleDeviationTransform;
25use crate::survival::construction::{
26 SavedSurvivalTimeBasis, SurvivalBaselineConfig, survival_baseline_targetname,
27};
28use crate::survival::location_scale::{
29 ResidualDistribution, residual_distribution_from_inverse_link,
30};
31use crate::transformation_normal::TransformationNormalFamily;
32use crate::inference::model::{
33 FittedFamily, FittedModelPayload, MODEL_PAYLOAD_VERSION, ModelKind, SavedAnchorComponent,
34 SavedAnchorKind, SavedCompiledFlexBlock, SavedLatentZNormalization, SavedResidualCascade,
35 SavedSplineScan, TransformationScoreCalibration,
36};
37use gam_terms::smooth::TermCollectionSpec;
38use gam_problem::types::{
39 InverseLink, LikelihoodSpec, ResponseFamily, StandardLink, inverse_link_to_binomial_spec,
40};
41use gam_data::DataSchema;
42use ndarray::Array2;
43
/// Model-class label passed to `FittedModelPayload::new` by
/// `assemble_bernoulli_marginal_slope_payload`.
const FAMILY_BERNOULLI_MARGINAL_SLOPE: &str = "bernoulli-marginal-slope";

/// Model-class label passed to `FittedModelPayload::new` by
/// `assemble_transformation_normal_payload`.
const FAMILY_TRANSFORMATION_NORMAL: &str = "transformation-normal";
49
50pub fn serialize_anchored_deviation_runtime(runtime: &DeviationRuntime) -> SavedCompiledFlexBlock {
57 let mut anchor_correction: Option<Vec<Vec<f64>>> = None;
58 let mut anchor_components: Vec<SavedAnchorComponent> = Vec::new();
59 if let Some(installed) = runtime.installed_flex_block() {
60 anchor_correction = Some(
61 installed
62 .anchor_correction
63 .rows()
64 .into_iter()
65 .map(|row| row.to_vec())
66 .collect::<Vec<Vec<f64>>>(),
67 );
68 for component in &installed.anchor_components {
69 anchor_components.push(SavedAnchorComponent {
70 kind: match component {
71 AnchorComponentTag::Parametric { block, ncols } => {
72 SavedAnchorKind::Parametric {
73 block: *block,
74 ncols: *ncols,
75 }
76 }
77 AnchorComponentTag::FlexEvaluation { ncols } => {
78 SavedAnchorKind::FlexEvaluation { ncols: *ncols }
79 }
80 },
81 });
82 }
83 }
84 SavedCompiledFlexBlock {
85 kernel: ANCHORED_DEVIATION_KERNEL.to_string(),
86 breakpoints: runtime.breakpoints().to_vec(),
87 basis_dim: runtime.basis_dim(),
88 span_c0: runtime
89 .span_c0()
90 .rows()
91 .into_iter()
92 .map(|row| row.to_vec())
93 .collect(),
94 span_c1: runtime
95 .span_c1()
96 .rows()
97 .into_iter()
98 .map(|row| row.to_vec())
99 .collect(),
100 span_c2: runtime
101 .span_c2()
102 .rows()
103 .into_iter()
104 .map(|row| row.to_vec())
105 .collect(),
106 span_c3: runtime
107 .span_c3()
108 .rows()
109 .into_iter()
110 .map(|row| row.to_vec())
111 .collect(),
112 anchor_correction,
113 anchor_components,
114 }
115}
116
/// Training-data provenance that the model-specific assemblers in this module
/// copy onto a payload through `SavedModelSourceMetadata::apply_to`.
pub struct SavedModelSourceMetadata {
    /// Column headers of the training data.
    pub training_headers: Vec<String>,
    /// Optional per-feature ranges (presumably `(min, max)` — confirm with the
    /// producer). When present, they are recorded together with the headers via
    /// `FittedModelPayload::set_training_feature_metadata`.
    pub training_feature_ranges: Option<Vec<(f64, f64)>>,
    /// Name of the offset column, if any; copied to `payload.offset_column`.
    pub offset_column: Option<String>,
    /// Name of the noise-offset column, if any; copied to
    /// `payload.noise_offset_column`.
    pub noise_offset_column: Option<String>,
}
130
131impl SavedModelSourceMetadata {
132 fn apply_to(self, payload: &mut FittedModelPayload) {
133 match self.training_feature_ranges {
134 Some(ranges) => payload.set_training_feature_metadata(self.training_headers, ranges),
135 None => payload.training_headers = Some(self.training_headers),
136 }
137 payload.offset_column = self.offset_column;
138 payload.noise_offset_column = self.noise_offset_column;
139 }
140}
141
/// Everything `assemble_bernoulli_marginal_slope_payload` needs to build a
/// saved Bernoulli marginal-slope model.
pub struct BernoulliMarginalSlopeInputs<'a> {
    /// Marginal (main) formula; becomes the payload's primary formula.
    pub formula: String,
    /// Schema of the training data.
    pub data_schema: DataSchema,
    /// Log-slope formula; stored both as `formula_logslope` and as the single
    /// entry of `formula_logslopes`.
    pub logslope_formula: String,
    /// Name of the latent-Z column; stored as `z_column` and in `z_columns`.
    pub z_column: String,
    /// Resolved term specification of the marginal predictor.
    pub resolved_marginalspec: TermCollectionSpec,
    /// Resolved term specification of the log-slope predictor.
    pub resolved_logslopespec: TermCollectionSpec,
    /// Fit result; its first block may be wider than `p_marginal` and is
    /// truncated to `p_marginal` coefficients before being saved.
    pub fit_result: UnifiedFitResult,
    /// Number of marginal-design coefficients kept in the first block.
    pub p_marginal: usize,
    /// Stored as `payload.marginal_baseline`.
    pub baseline_marginal: f64,
    /// Stored as `logslope_baseline` and as the single entry of
    /// `logslope_baselines`.
    pub baseline_logslope: f64,
    /// Latent-Z normalization saved alongside the model.
    pub latent_z_normalization: SavedLatentZNormalization,
    /// Latent measure saved alongside the model.
    pub latent_measure: LatentMeasureKind,
    /// Optional rank-INT calibration for the latent Z.
    pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
    /// Optional conditional calibration for the latent Z.
    pub latent_z_conditional_calibration: Option<LatentZConditionalCalibration>,
    /// Optional score-warp runtime, serialized with
    /// `serialize_anchored_deviation_runtime`.
    pub score_warp_runtime: Option<&'a DeviationRuntime>,
    /// Optional link-deviation runtime, serialized with
    /// `serialize_anchored_deviation_runtime`.
    pub link_dev_runtime: Option<&'a DeviationRuntime>,
    /// Base inverse link; converted into the binomial likelihood spec and also
    /// stored as `payload.link`.
    pub base_link: InverseLink,
    /// Frailty specification recorded in the fitted family.
    pub frailty: crate::survival::lognormal_kernel::FrailtySpec,
}
180
181fn truncate_marginal_slope_influence_absorber(
214 fit_result: UnifiedFitResult,
215 p_marginal: usize,
216) -> Result<UnifiedFitResult, String> {
217 let Some(block0) = fit_result.blocks.first() else {
218 return Err("marginal-slope fit result has no coefficient blocks".to_string());
219 };
220 let widened_len = block0.beta.len();
221 if widened_len <= p_marginal {
222 return Ok(fit_result);
224 }
225 let p_influence = widened_len - p_marginal;
226
227 let UnifiedFitResult {
228 mut blocks,
229 log_lambdas,
230 lambdas,
231 likelihood_family,
232 likelihood_scale,
233 log_likelihood_normalization,
234 log_likelihood,
235 deviance,
236 reml_score,
237 stable_penalty_term,
238 penalized_objective,
239 used_device,
240 outer_iterations,
241 outer_converged,
242 outer_gradient_norm,
243 standard_deviation,
244 covariance_conditional,
245 covariance_corrected,
246 inference,
247 fitted_link,
248 geometry: _,
249 mut block_states,
250 beta: _,
251 pirls_status,
252 max_abs_eta,
253 constraint_kkt,
254 artifacts,
255 inner_cycles,
256 outer_cost_evals: _,
257 inner_pirls_solves: _,
258 } = fit_result;
259
260 blocks[0].beta = blocks[0].beta.slice(ndarray::s![..p_marginal]).to_owned();
263 if let Some(state0) = block_states.first_mut() {
264 state0.beta = state0.beta.slice(ndarray::s![..p_marginal]).to_owned();
265 }
266
267 let drop_gamma_block = |cov: Option<Array2<f64>>| -> Option<Array2<f64>> {
270 cov.map(|cov| {
271 let total = cov.nrows();
272 let kept: Vec<usize> = (0..p_marginal)
273 .chain((p_marginal + p_influence)..total)
274 .collect();
275 let mut out = Array2::<f64>::zeros((kept.len(), kept.len()));
276 for (ri, &r) in kept.iter().enumerate() {
277 for (ci, &c) in kept.iter().enumerate() {
278 out[[ri, ci]] = cov[[r, c]];
279 }
280 }
281 out
282 })
283 };
284 let covariance_conditional = drop_gamma_block(covariance_conditional);
285 let covariance_corrected = drop_gamma_block(covariance_corrected);
286
287 UnifiedFitResult::try_from_parts(gam_solve::estimate::UnifiedFitResultParts {
288 blocks,
289 log_lambdas,
290 lambdas,
291 likelihood_family,
292 likelihood_scale,
293 log_likelihood_normalization,
294 log_likelihood,
295 deviance,
296 reml_score,
297 stable_penalty_term,
298 penalized_objective,
299 used_device,
303 outer_iterations,
304 outer_converged,
305 outer_gradient_norm,
306 standard_deviation,
307 covariance_conditional,
308 covariance_corrected,
309 inference,
310 fitted_link,
311 geometry: None,
313 block_states,
314 pirls_status,
315 max_abs_eta,
316 constraint_kkt,
317 artifacts,
318 inner_cycles,
319 })
320 .map_err(|e| {
321 format!("marginal-slope influence-absorber truncation produced an invalid fit result: {e}")
322 })
323}
324
325pub fn assemble_spline_scan_payload(
331 formula: String,
332 feature_column: String,
333 fit: &gam_solve::spline_scan::SplineScanFit,
334 data_schema: DataSchema,
335 training_headers: Vec<String>,
336 training_feature_ranges: Vec<(f64, f64)>,
337) -> FittedModelPayload {
338 let mut payload = FittedModelPayload::new(
339 MODEL_PAYLOAD_VERSION,
340 formula,
341 ModelKind::Standard,
342 FittedFamily::Standard {
343 likelihood: LikelihoodSpec::gaussian_identity(),
344 link: None,
345 latent_cloglog_state: None,
346 mixture_state: None,
347 sas_state: None,
348 },
349 "gaussian".to_string(),
350 );
351 payload.spline_scan = Some(SavedSplineScan {
352 feature_column,
353 state: fit.to_state(),
354 });
355 payload.data_schema = Some(data_schema);
356 payload.set_training_feature_metadata(training_headers, training_feature_ranges);
357 payload
358}
359
360pub fn assemble_residual_cascade_payload(
366 formula: String,
367 feature_columns: Vec<String>,
368 fit: &gam_solve::residual_cascade::ResidualCascadeFit,
369 data_schema: DataSchema,
370 training_headers: Vec<String>,
371 training_feature_ranges: Vec<(f64, f64)>,
372) -> Result<FittedModelPayload, String> {
373 let mut payload = FittedModelPayload::new(
374 MODEL_PAYLOAD_VERSION,
375 formula,
376 ModelKind::Standard,
377 FittedFamily::Standard {
378 likelihood: gam_problem::types::LikelihoodSpec::gaussian_identity(),
379 link: None,
380 latent_cloglog_state: None,
381 mixture_state: None,
382 sas_state: None,
383 },
384 "gaussian".to_string(),
385 );
386 payload.residual_cascade = Some(SavedResidualCascade {
387 feature_columns,
388 state: fit.to_state().map_err(|e| {
389 format!("residual-cascade to_state failed during payload assembly: {e}")
390 })?,
391 });
392 payload.data_schema = Some(data_schema);
393 payload.set_training_feature_metadata(training_headers, training_feature_ranges);
394 Ok(payload)
395}
396
397pub fn assemble_bernoulli_marginal_slope_payload(
405 inputs: BernoulliMarginalSlopeInputs<'_>,
406 source: SavedModelSourceMetadata,
407) -> Result<FittedModelPayload, String> {
408 let BernoulliMarginalSlopeInputs {
409 formula,
410 data_schema,
411 logslope_formula,
412 z_column,
413 resolved_marginalspec,
414 resolved_logslopespec,
415 fit_result,
416 p_marginal,
417 baseline_marginal,
418 baseline_logslope,
419 latent_z_normalization,
420 latent_measure,
421 latent_z_rank_int_calibration,
422 latent_z_conditional_calibration,
423 score_warp_runtime,
424 link_dev_runtime,
425 base_link,
426 frailty,
427 } = inputs;
428
429 let fit_result = truncate_marginal_slope_influence_absorber(fit_result, p_marginal)?;
433
434 let marginal_likelihood_spec =
435 inverse_link_to_binomial_spec(&base_link).map_err(|e| e.to_string())?;
436
437 let mut payload = FittedModelPayload::new(
438 MODEL_PAYLOAD_VERSION,
439 formula,
440 ModelKind::MarginalSlope,
441 FittedFamily::MarginalSlope {
442 likelihood: marginal_likelihood_spec,
443 base_link: base_link.clone(),
444 frailty,
445 },
446 FAMILY_BERNOULLI_MARGINAL_SLOPE.to_string(),
447 );
448 payload.unified = Some(fit_result.clone());
449 payload.fit_result = Some(fit_result);
450 payload.data_schema = Some(data_schema);
451 payload.formula_logslope = Some(logslope_formula.clone());
452 payload.z_column = Some(z_column.clone());
453 payload.formula_logslopes = Some(vec![logslope_formula]);
454 payload.z_columns = Some(vec![z_column]);
455 payload.latent_z_normalization = Some(latent_z_normalization);
456 payload.latent_measure = Some(latent_measure);
457 payload.latent_z_rank_int_calibration = latent_z_rank_int_calibration;
458 payload.latent_z_conditional_calibration = latent_z_conditional_calibration;
459 payload.marginal_baseline = Some(baseline_marginal);
460 payload.logslope_baseline = Some(baseline_logslope);
461 payload.logslope_baselines = Some(vec![baseline_logslope]);
462 payload.link = Some(base_link);
463 payload.resolved_termspec = Some(resolved_marginalspec);
464 payload.resolved_termspec_logslopes = Some(vec![resolved_logslopespec.clone()]);
465 payload.resolved_termspec_logslope = Some(resolved_logslopespec);
466 payload.score_warp_runtime = score_warp_runtime.map(serialize_anchored_deviation_runtime);
467 payload.link_deviation_runtime = link_dev_runtime.map(serialize_anchored_deviation_runtime);
468 source.apply_to(&mut payload);
469 Ok(payload)
470}
471
/// Everything `assemble_transformation_normal_payload` needs to build a saved
/// transformation-normal model.
pub struct TransformationNormalInputs<'a> {
    /// Model formula; becomes the payload's primary formula.
    pub formula: String,
    /// Schema of the training data.
    pub data_schema: DataSchema,
    /// Resolved covariate term specification; stored as `resolved_termspec`.
    pub resolved_covariate_spec: TermCollectionSpec,
    /// Fit result, stored in both `unified` and `fit_result`.
    pub fit_result: UnifiedFitResult,
    /// Fitted family; source of the response knots, transform matrix, degree
    /// and median written to the payload.
    pub family: &'a TransformationNormalFamily,
    /// Score calibration stored as `transformation_score_calibration`.
    pub score_calibration: TransformationScoreCalibration,
}
486
487pub fn assemble_transformation_normal_payload(
493 inputs: TransformationNormalInputs<'_>,
494 source: SavedModelSourceMetadata,
495) -> FittedModelPayload {
496 let TransformationNormalInputs {
497 formula,
498 data_schema,
499 resolved_covariate_spec,
500 fit_result,
501 family,
502 score_calibration,
503 } = inputs;
504
505 let mut payload = FittedModelPayload::new(
506 MODEL_PAYLOAD_VERSION,
507 formula,
508 ModelKind::TransformationNormal,
509 FittedFamily::TransformationNormal {
510 likelihood: LikelihoodSpec::new(
511 ResponseFamily::Gaussian,
512 InverseLink::Standard(StandardLink::Identity),
513 ),
514 },
515 FAMILY_TRANSFORMATION_NORMAL.to_string(),
516 );
517 payload.unified = Some(fit_result.clone());
518 payload.fit_result = Some(fit_result);
519 payload.data_schema = Some(data_schema);
520 payload.resolved_termspec = Some(resolved_covariate_spec);
521 payload.transformation_response_knots = Some(family.response_knots().to_vec());
522 payload.transformation_response_transform = Some(
523 family
524 .response_transform()
525 .rows()
526 .into_iter()
527 .map(|row| row.to_vec())
528 .collect(),
529 );
530 payload.transformation_response_degree = Some(family.response_degree());
531 payload.transformation_response_median = Some(family.response_median());
532 payload.transformation_score_calibration = Some(score_calibration);
533 source.apply_to(&mut payload);
534 payload
535}
536
/// Response-specific settings for `assemble_location_scale_payload`.
pub enum LocationScaleResponse<'a> {
    /// Gaussian response.
    ///
    /// The fitted family carries no base link. The payload link is `base_link`,
    /// or identity when `None`. `response_scale` is stored as
    /// `gaussian_response_scale`.
    Gaussian {
        response_scale: f64,
        base_link: Option<InverseLink>,
    },
    /// Binomial response.
    ///
    /// `link` is converted into a binomial likelihood spec and used as both the
    /// family base link and the payload link. `noise_transform` is flattened
    /// into the payload's `noise_*` fields.
    Binomial {
        link: InverseLink,
        noise_transform: &'a ScaleDeviationTransform,
    },
    /// Any other dispersion family.
    ///
    /// `likelihood` is used verbatim. `base_link` serves as both the family
    /// base link and the payload link. `family_tag` becomes the model-class
    /// label.
    Dispersion {
        likelihood: LikelihoodSpec,
        base_link: InverseLink,
        family_tag: &'static str,
    },
}
565
/// Optional link-wiggle basis and coefficients for a location-scale fit.
pub struct LocationScaleWiggle {
    /// Knot vector; stored as `linkwiggle_knots`.
    pub knots: Vec<f64>,
    /// Basis degree; stored as `linkwiggle_degree`.
    pub degree: usize,
    /// Wiggle coefficients; stored as `beta_link_wiggle`.
    pub beta_link_wiggle: Vec<f64>,
}
575
/// Response-independent inputs for `assemble_location_scale_payload`.
pub struct LocationScaleInputs {
    /// Location formula; becomes the payload's primary formula.
    pub formula: String,
    /// Schema of the training data.
    pub data_schema: DataSchema,
    /// Noise (scale) formula; stored as `formula_noise`.
    pub noise_formula: String,
    /// Resolved term specification of the location predictor.
    pub resolved_termspec: TermCollectionSpec,
    /// Resolved term specification of the noise predictor.
    pub resolved_termspec_noise: TermCollectionSpec,
    /// Fit result, stored in both `unified` and `fit_result`.
    pub fit_result: UnifiedFitResult,
    /// Optional noise coefficients; stored as `beta_noise`.
    pub beta_noise: Option<Vec<f64>>,
    /// Optional link-wiggle component.
    pub wiggle: Option<LocationScaleWiggle>,
}
589
590pub fn assemble_location_scale_payload(
595 inputs: LocationScaleInputs,
596 response: LocationScaleResponse<'_>,
597 source: SavedModelSourceMetadata,
598) -> Result<FittedModelPayload, String> {
599 let (family_tag, likelihood, base_link, link, response_scale, noise_transform) = match response
600 {
601 LocationScaleResponse::Gaussian {
602 response_scale,
603 base_link,
604 } => (
605 "gaussian-location-scale".to_string(),
606 LikelihoodSpec::gaussian_identity(),
607 None,
611 Some(base_link.unwrap_or(InverseLink::Standard(StandardLink::Identity))),
612 Some(response_scale),
613 None,
614 ),
615 LocationScaleResponse::Binomial {
616 link,
617 noise_transform,
618 } => {
619 let likelihood = inverse_link_to_binomial_spec(&link).map_err(|e| {
620 format!("failed to resolve LikelihoodSpec for binomial location-scale link {link:?}: {e}")
621 })?;
622 (
623 "binomial-location-scale".to_string(),
624 likelihood,
625 Some(link.clone()),
626 Some(link),
627 None,
628 Some(noise_transform),
629 )
630 }
631 LocationScaleResponse::Dispersion {
632 likelihood,
633 base_link,
634 family_tag,
635 } => (
636 family_tag.to_string(),
637 likelihood,
638 Some(base_link.clone()),
639 Some(base_link),
640 None,
641 None,
642 ),
643 };
644
645 let mut payload = FittedModelPayload::new(
646 MODEL_PAYLOAD_VERSION,
647 inputs.formula,
648 ModelKind::LocationScale,
649 FittedFamily::LocationScale {
650 likelihood,
651 base_link,
652 },
653 family_tag,
654 );
655 payload.unified = Some(inputs.fit_result.clone());
656 payload.fit_result = Some(inputs.fit_result);
657 payload.data_schema = Some(inputs.data_schema);
658 payload.link = link;
659 payload.formula_noise = Some(inputs.noise_formula);
660 payload.beta_noise = inputs.beta_noise;
661 payload.gaussian_response_scale = response_scale;
662 if let Some(transform) = noise_transform {
663 payload.noise_projection = Some(
664 transform
665 .projection_coef
666 .rows()
667 .into_iter()
668 .map(|row| row.to_vec())
669 .collect(),
670 );
671 payload.noise_center = Some(transform.weighted_column_mean.to_vec());
672 payload.noise_scale = Some(transform.rescale.to_vec());
673 payload.noise_non_intercept_start = Some(transform.non_intercept_start);
674 payload.noise_projection_ridge_alpha = Some(transform.projection_ridge_alpha);
675 }
676 payload.resolved_termspec = Some(inputs.resolved_termspec);
677 payload.resolved_termspec_noise = Some(inputs.resolved_termspec_noise);
678 if let Some(wiggle) = inputs.wiggle {
679 payload.linkwiggle_knots = Some(wiggle.knots);
680 payload.linkwiggle_degree = Some(wiggle.degree);
681 payload.beta_link_wiggle = Some(wiggle.beta_link_wiggle);
682 }
683 source.apply_to(&mut payload);
684 Ok(payload)
685}
686
/// Everything `assemble_survival_marginal_slope_payload` needs to build a
/// saved survival marginal-slope model (Royston-Parmar response family).
pub struct SurvivalMarginalSlopeInputs<'a> {
    /// Marginal formula; becomes the payload's primary formula.
    pub formula: String,
    /// Schema of the training data.
    pub data_schema: DataSchema,
    /// Fit result, stored in both `unified` and `fit_result` (not truncated).
    pub fit_result: UnifiedFitResult,
    /// Frailty specification recorded in the survival family.
    pub frailty: crate::survival::lognormal_kernel::FrailtySpec,
    /// Optional entry-time column name.
    pub survival_entry: Option<String>,
    /// Exit-time column name.
    pub survival_exit: String,
    /// Event-indicator column name.
    pub survival_event: String,
    /// Survival specification string; stored as `survivalspec`.
    pub survivalspec: String,
    /// Baseline configuration; its target, scale, shape, rate and makeham
    /// entries are copied onto the payload.
    pub baseline_cfg: SurvivalBaselineConfig,
    /// Time basis applied via `apply_survival_time_basis`.
    pub time_basis: SavedSurvivalTimeBasis,
    /// Stored as `survivalridge_lambda`.
    pub ridge_lambda: f64,
    /// Survival likelihood label recorded in the family and on the payload.
    pub survival_likelihood_label: String,
    /// Resolved marginal term specification; stored as `resolved_termspec`.
    pub resolved_marginalspec: TermCollectionSpec,
    /// Resolved log-slope term specification.
    pub resolved_logslopespec: TermCollectionSpec,
    /// Log-slope formula.
    pub logslope_formula: String,
    /// Latent-Z column name.
    pub z_column: String,
    /// Latent-Z normalization saved alongside the model.
    pub latent_z_normalization: SavedLatentZNormalization,
    /// Log-slope baseline.
    pub baseline_logslope: f64,
    /// Optional score-warp runtime, serialized with
    /// `serialize_anchored_deviation_runtime`.
    pub score_warp_runtime: Option<&'a DeviationRuntime>,
    /// Optional link-deviation runtime, serialized with
    /// `serialize_anchored_deviation_runtime`.
    pub link_dev_runtime: Option<&'a DeviationRuntime>,
    /// Influence-absorber width, stored as-is on the payload.
    pub influence_absorber_width: Option<usize>,
}
718
719fn new_royston_parmar_survival_payload(
727 formula: String,
728 fit_result: UnifiedFitResult,
729 data_schema: DataSchema,
730 survival_likelihood_label: &str,
731 survival_distribution: Option<ResidualDistribution>,
732 frailty: crate::survival::lognormal_kernel::FrailtySpec,
733) -> FittedModelPayload {
734 let mut payload = FittedModelPayload::new(
735 MODEL_PAYLOAD_VERSION,
736 formula,
737 ModelKind::Survival,
738 FittedFamily::Survival {
739 likelihood: LikelihoodSpec::new(
740 ResponseFamily::RoystonParmar,
741 InverseLink::Standard(StandardLink::Identity),
742 ),
743 survival_likelihood: Some(survival_likelihood_label.to_string()),
744 survival_distribution,
745 frailty,
746 },
747 ResponseFamily::RoystonParmar.name().to_string(),
748 );
749 payload.unified = Some(fit_result.clone());
750 payload.fit_result = Some(fit_result);
751 payload.data_schema = Some(data_schema);
752 payload
753}
754
755pub fn assemble_survival_marginal_slope_payload(
758 inputs: SurvivalMarginalSlopeInputs<'_>,
759 source: SavedModelSourceMetadata,
760) -> FittedModelPayload {
761 let mut payload = new_royston_parmar_survival_payload(
762 inputs.formula,
763 inputs.fit_result,
764 inputs.data_schema,
765 &inputs.survival_likelihood_label,
766 Some(ResidualDistribution::Gaussian),
767 inputs.frailty,
768 );
769 payload.survival_entry = inputs.survival_entry;
770 payload.survival_exit = Some(inputs.survival_exit);
771 payload.survival_event = Some(inputs.survival_event);
772 payload.survivalspec = Some(inputs.survivalspec);
773 payload.survival_baseline_target =
774 Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
775 payload.survival_baseline_scale = inputs.baseline_cfg.scale;
776 payload.survival_baseline_shape = inputs.baseline_cfg.shape;
777 payload.survival_baseline_rate = inputs.baseline_cfg.rate;
778 payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
779 payload.apply_survival_time_basis(&inputs.time_basis);
780 payload.survivalridge_lambda = Some(inputs.ridge_lambda);
781 payload.survival_likelihood = Some(inputs.survival_likelihood_label);
782 payload.survival_distribution = Some(ResidualDistribution::Gaussian);
783 payload.link = Some(InverseLink::Standard(StandardLink::Probit));
784 payload.resolved_termspec = Some(inputs.resolved_marginalspec);
785 payload.resolved_termspec_logslopes = Some(vec![inputs.resolved_logslopespec.clone()]);
786 payload.resolved_termspec_logslope = Some(inputs.resolved_logslopespec);
787 payload.formula_logslope = Some(inputs.logslope_formula.clone());
788 payload.formula_logslopes = Some(vec![inputs.logslope_formula]);
789 payload.z_column = Some(inputs.z_column.clone());
790 payload.z_columns = Some(vec![inputs.z_column]);
791 payload.latent_z_normalization = Some(inputs.latent_z_normalization);
792 payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
793 payload.logslope_baseline = Some(inputs.baseline_logslope);
794 payload.logslope_baselines = Some(vec![inputs.baseline_logslope]);
795 payload.score_warp_runtime = inputs
796 .score_warp_runtime
797 .map(serialize_anchored_deviation_runtime);
798 payload.link_deviation_runtime = inputs
799 .link_dev_runtime
800 .map(serialize_anchored_deviation_runtime);
801 payload.influence_absorber_width = inputs.influence_absorber_width;
802 source.apply_to(&mut payload);
803 payload
804}
805
/// Baseline time-wiggle coefficients in one of two layouts.
///
/// `assemble_survival_transformation_payload` stores `Single` in
/// `beta_baseline_timewiggle` and `ByCause` in
/// `beta_baseline_timewiggle_by_cause`.
pub enum SurvivalTimewiggleBeta {
    /// One coefficient vector.
    Single(Vec<f64>),
    /// Presumably one coefficient vector per cause — confirm with the producer.
    ByCause(Vec<Vec<f64>>),
}
812
/// Baseline time-wiggle basis description and its coefficients.
pub struct SurvivalTimewiggle {
    /// Basis degree; stored as `baseline_timewiggle_degree`.
    pub degree: usize,
    /// Knot vector; stored as `baseline_timewiggle_knots`.
    pub knots: Vec<f64>,
    /// Optional penalty orders; stored as `baseline_timewiggle_penalty_orders`.
    pub penalty_orders: Option<Vec<usize>>,
    /// Optional double-penalty flag; stored as
    /// `baseline_timewiggle_double_penalty`.
    pub double_penalty: Option<bool>,
    /// Wiggle coefficients.
    pub beta: SurvivalTimewiggleBeta,
}
821
/// Everything `assemble_survival_transformation_payload` needs to build a
/// saved survival transformation model (Royston-Parmar response family).
pub struct SurvivalTransformationInputs {
    /// Model formula; becomes the payload's primary formula.
    pub formula: String,
    /// Schema of the training data.
    pub data_schema: DataSchema,
    /// Fit result, stored in both `unified` and `fit_result`.
    pub fit_result: UnifiedFitResult,
    /// Optional entry-time column name.
    pub survival_entry: Option<String>,
    /// Exit-time column name.
    pub survival_exit: String,
    /// Event-indicator column name.
    pub survival_event: String,
    /// Survival specification string; stored as `survivalspec`.
    pub survivalspec: String,
    /// When `Some(n)`, the payload records `n` causes with endpoint names
    /// `cause_1` .. `cause_n`.
    pub cause_count: Option<usize>,
    /// Baseline configuration; its target, scale, shape, rate and makeham
    /// entries are copied onto the payload.
    pub baseline_cfg: SurvivalBaselineConfig,
    /// Time basis applied via `apply_survival_time_basis`.
    pub time_basis: SavedSurvivalTimeBasis,
    /// Stored as `survivalridge_lambda`.
    pub ridge_lambda: f64,
    /// Survival likelihood label recorded in the family and on the payload.
    pub survival_likelihood_label: String,
    /// Resolved term specification; stored as `resolved_termspec`.
    pub resolved_termspec: TermCollectionSpec,
    /// Optional time coefficients; stored as `survival_beta_time`.
    pub survival_beta_time: Option<Vec<f64>>,
    /// Optional baseline time-wiggle component.
    pub timewiggle: Option<SurvivalTimewiggle>,
}
844
845pub fn assemble_survival_transformation_payload(
848 inputs: SurvivalTransformationInputs,
849 source: SavedModelSourceMetadata,
850) -> FittedModelPayload {
851 let mut payload = new_royston_parmar_survival_payload(
852 inputs.formula,
853 inputs.fit_result,
854 inputs.data_schema,
855 &inputs.survival_likelihood_label,
856 None,
857 crate::survival::lognormal_kernel::FrailtySpec::None,
858 );
859 payload.survival_entry = inputs.survival_entry;
860 payload.survival_exit = Some(inputs.survival_exit);
861 payload.survival_event = Some(inputs.survival_event);
862 payload.survivalspec = Some(inputs.survivalspec);
863 if let Some(cause_count) = inputs.cause_count {
864 payload.survival_cause_count = Some(cause_count);
865 payload.survival_endpoint_names = Some(
866 (1..=cause_count)
867 .map(|idx| format!("cause_{idx}"))
868 .collect(),
869 );
870 }
871 payload.survival_baseline_target =
872 Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
873 payload.survival_baseline_scale = inputs.baseline_cfg.scale;
874 payload.survival_baseline_shape = inputs.baseline_cfg.shape;
875 payload.survival_baseline_rate = inputs.baseline_cfg.rate;
876 payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
877 payload.apply_survival_time_basis(&inputs.time_basis);
878 if let Some(timewiggle) = inputs.timewiggle {
879 payload.baseline_timewiggle_degree = Some(timewiggle.degree);
880 payload.baseline_timewiggle_knots = Some(timewiggle.knots);
881 payload.baseline_timewiggle_penalty_orders = timewiggle.penalty_orders;
882 payload.baseline_timewiggle_double_penalty = timewiggle.double_penalty;
883 match timewiggle.beta {
884 SurvivalTimewiggleBeta::Single(beta) => {
885 payload.beta_baseline_timewiggle = Some(beta);
886 }
887 SurvivalTimewiggleBeta::ByCause(by_cause) => {
888 payload.beta_baseline_timewiggle_by_cause = Some(by_cause);
889 }
890 }
891 }
892 payload.survivalridge_lambda = Some(inputs.ridge_lambda);
893 payload.survival_likelihood = Some(inputs.survival_likelihood_label);
894 payload.survival_beta_time = inputs.survival_beta_time;
895 payload.resolved_termspec = Some(inputs.resolved_termspec);
896 source.apply_to(&mut payload);
897 payload
898}
899
/// Everything `assemble_survival_location_scale_payload` needs to build a
/// saved survival location-scale model (Royston-Parmar response family).
pub struct SurvivalLocationScaleInputs<'a> {
    /// Threshold formula; becomes the payload's primary formula.
    pub formula: String,
    /// Schema of the training data.
    pub data_schema: DataSchema,
    /// Fit result, stored in both `unified` and `fit_result`.
    pub fit_result: UnifiedFitResult,
    /// Fitted inverse link. The residual distribution is derived from it via
    /// `residual_distribution_from_inverse_link`, and it is stored as
    /// `payload.link`.
    pub fitted_inverse_link: InverseLink,
    /// Optional link-wiggle degree.
    pub linkwiggle_degree: Option<usize>,
    /// Optional link-wiggle knot vector.
    pub linkwiggle_knots: Option<Vec<f64>>,
    /// Optional link-wiggle coefficients.
    pub beta_link_wiggle: Option<Vec<f64>>,
    /// Optional baseline time-wiggle basis and coefficients.
    pub baseline_timewiggle: Option<SurvivalTimewiggle>,
    /// Optional entry-time column name.
    pub survival_entry: Option<String>,
    /// Exit-time column name.
    pub survival_exit: String,
    /// Event-indicator column name.
    pub survival_event: String,
    /// Survival specification string; stored as `survivalspec`.
    pub survivalspec: String,
    /// Baseline configuration; its target, scale, shape, rate and makeham
    /// entries are copied onto the payload.
    pub baseline_cfg: SurvivalBaselineConfig,
    /// Time basis applied via `apply_survival_time_basis`.
    pub time_basis: SavedSurvivalTimeBasis,
    /// Stored as `survivalridge_lambda`.
    pub ridge_lambda: f64,
    /// Survival likelihood label recorded in the family and on the payload.
    pub survival_likelihood_label: String,
    /// Optional noise formula; stored as `formula_noise`.
    pub formula_noise: Option<String>,
    /// Time coefficients; stored as `survival_beta_time`.
    pub survival_beta_time: Vec<f64>,
    /// Threshold coefficients; stored as `survival_beta_threshold`.
    pub survival_beta_threshold: Vec<f64>,
    /// Log-sigma coefficients; stored as `survival_beta_log_sigma`.
    pub survival_beta_log_sigma: Vec<f64>,
    /// Noise-design transform, flattened into the `survival_noise_*` fields.
    pub noise_transform: &'a ScaleDeviationTransform,
    /// Resolved threshold term specification; stored as `resolved_termspec`.
    pub resolved_thresholdspec: TermCollectionSpec,
    /// Resolved log-sigma term specification; stored as
    /// `resolved_termspec_noise`.
    pub resolved_log_sigmaspec: TermCollectionSpec,
}
933
934pub fn assemble_survival_location_scale_payload(
937 inputs: SurvivalLocationScaleInputs<'_>,
938 source: SavedModelSourceMetadata,
939) -> FittedModelPayload {
940 let survival_distribution =
941 residual_distribution_from_inverse_link(&inputs.fitted_inverse_link);
942 let mut payload = new_royston_parmar_survival_payload(
943 inputs.formula,
944 inputs.fit_result,
945 inputs.data_schema,
946 &inputs.survival_likelihood_label,
947 survival_distribution,
948 crate::survival::lognormal_kernel::FrailtySpec::None,
949 );
950 payload.link = Some(inputs.fitted_inverse_link);
951 payload.linkwiggle_degree = inputs.linkwiggle_degree;
952 payload.linkwiggle_knots = inputs.linkwiggle_knots;
953 payload.beta_link_wiggle = inputs.beta_link_wiggle;
954 if let Some(timewiggle) = inputs.baseline_timewiggle {
955 payload.baseline_timewiggle_degree = Some(timewiggle.degree);
956 payload.baseline_timewiggle_knots = Some(timewiggle.knots);
957 payload.baseline_timewiggle_penalty_orders = timewiggle.penalty_orders;
958 payload.baseline_timewiggle_double_penalty = timewiggle.double_penalty;
959 if let SurvivalTimewiggleBeta::Single(beta) = timewiggle.beta {
960 payload.beta_baseline_timewiggle = Some(beta);
961 }
962 }
963 payload.survival_entry = inputs.survival_entry;
964 payload.survival_exit = Some(inputs.survival_exit);
965 payload.survival_event = Some(inputs.survival_event);
966 payload.survivalspec = Some(inputs.survivalspec);
967 payload.survival_baseline_target =
968 Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
969 payload.survival_baseline_scale = inputs.baseline_cfg.scale;
970 payload.survival_baseline_shape = inputs.baseline_cfg.shape;
971 payload.survival_baseline_rate = inputs.baseline_cfg.rate;
972 payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
973 payload.apply_survival_time_basis(&inputs.time_basis);
974 payload.survivalridge_lambda = Some(inputs.ridge_lambda);
975 payload.survival_likelihood = Some(inputs.survival_likelihood_label);
976 payload.formula_noise = inputs.formula_noise;
977 payload.survival_beta_time = Some(inputs.survival_beta_time);
978 payload.survival_beta_threshold = Some(inputs.survival_beta_threshold);
979 payload.survival_beta_log_sigma = Some(inputs.survival_beta_log_sigma);
980 payload.survival_noise_projection = Some(
981 inputs
982 .noise_transform
983 .projection_coef
984 .rows()
985 .into_iter()
986 .map(|row| row.to_vec())
987 .collect(),
988 );
989 payload.survival_noise_center = Some(inputs.noise_transform.weighted_column_mean.to_vec());
990 payload.survival_noise_scale = Some(inputs.noise_transform.rescale.to_vec());
991 payload.survival_noise_non_intercept_start = Some(inputs.noise_transform.non_intercept_start);
992 payload.survival_noise_projection_ridge_alpha =
993 Some(inputs.noise_transform.projection_ridge_alpha);
994 payload.survival_distribution = survival_distribution;
995 payload.resolved_termspec = Some(inputs.resolved_thresholdspec);
996 payload.resolved_termspec_noise = Some(inputs.resolved_log_sigmaspec);
997 source.apply_to(&mut payload);
998 payload
999}
1000
/// Everything `assemble_latent_window_payload` needs to build a saved
/// latent-window survival model.
pub struct LatentWindowInputs {
    /// Model formula; becomes the payload's primary formula.
    pub formula: String,
    /// Schema of the training data.
    pub data_schema: DataSchema,
    /// Fit result, stored in both `unified` and `fit_result`.
    pub fit_result: UnifiedFitResult,
    /// Fitted family, used verbatim for the payload.
    pub family: FittedFamily,
    /// Model-class label passed to `FittedModelPayload::new`.
    pub model_class_label: String,
    /// Stored as `survival_likelihood`.
    pub likelihood_label: String,
    /// Optional entry-time column name.
    pub survival_entry: Option<String>,
    /// Exit-time column name.
    pub survival_exit: String,
    /// Event-indicator column name.
    pub survival_event: String,
    /// Baseline configuration; its target, scale, shape, rate and makeham
    /// entries are copied onto the payload.
    pub baseline_cfg: SurvivalBaselineConfig,
    /// Time basis applied via `apply_survival_time_basis`.
    pub time_basis: SavedSurvivalTimeBasis,
    /// Stored as `survivalridge_lambda`.
    pub ridge_lambda: f64,
    /// Time coefficients; stored as `survival_beta_time`.
    pub beta_time: Vec<f64>,
    /// Resolved term specification; stored as `resolved_termspec`.
    pub resolved_termspec: TermCollectionSpec,
}
1020
1021pub fn assemble_latent_window_payload(
1023 inputs: LatentWindowInputs,
1024 source: SavedModelSourceMetadata,
1025) -> FittedModelPayload {
1026 let mut payload = FittedModelPayload::new(
1027 MODEL_PAYLOAD_VERSION,
1028 inputs.formula,
1029 ModelKind::Survival,
1030 inputs.family,
1031 inputs.model_class_label,
1032 );
1033 payload.unified = Some(inputs.fit_result.clone());
1034 payload.fit_result = Some(inputs.fit_result);
1035 payload.data_schema = Some(inputs.data_schema);
1036 payload.survival_entry = inputs.survival_entry;
1037 payload.survival_exit = Some(inputs.survival_exit);
1038 payload.survival_event = Some(inputs.survival_event);
1039 payload.survivalspec = Some("net".to_string());
1040 payload.survival_baseline_target =
1041 Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
1042 payload.survival_baseline_scale = inputs.baseline_cfg.scale;
1043 payload.survival_baseline_shape = inputs.baseline_cfg.shape;
1044 payload.survival_baseline_rate = inputs.baseline_cfg.rate;
1045 payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
1046 payload.apply_survival_time_basis(&inputs.time_basis);
1047 payload.survival_likelihood = Some(inputs.likelihood_label);
1048 payload.survival_beta_time = Some(inputs.beta_time);
1049 payload.survivalridge_lambda = Some(inputs.ridge_lambda);
1050 payload.resolved_termspec = Some(inputs.resolved_termspec);
1051 source.apply_to(&mut payload);
1052 payload
1053}