use crate::estimate::UnifiedFitResult;
use crate::families::bms::deviation_runtime::AnchorComponentTag;
use crate::families::bms::{DeviationRuntime, LatentMeasureKind, LatentZRankIntCalibration};
use crate::families::cubic_cell_kernel::ANCHORED_DEVIATION_KERNEL;
use crate::families::scale_design::ScaleDeviationTransform;
use crate::families::survival_construction::{
SavedSurvivalTimeBasis, SurvivalBaselineConfig, survival_baseline_targetname,
};
use crate::families::survival_location_scale::{
ResidualDistribution, residual_distribution_from_inverse_link,
};
use crate::families::transformation_normal::TransformationNormalFamily;
use crate::inference::model::{
DataSchema, FittedFamily, FittedModelPayload, MODEL_PAYLOAD_VERSION, ModelKind,
SavedAnchorComponent, SavedAnchorKind, SavedCompiledFlexBlock, SavedLatentZNormalization,
TransformationScoreCalibration,
};
use crate::smooth::TermCollectionSpec;
use crate::types::{
InverseLink, LikelihoodSpec, ResponseFamily, StandardLink, inverse_link_to_binomial_spec,
};
use ndarray::Array2;
/// Family label passed as the last argument of `FittedModelPayload::new` by
/// `assemble_bernoulli_marginal_slope_payload`.
const FAMILY_BERNOULLI_MARGINAL_SLOPE: &str = "bernoulli-marginal-slope";
/// Family label passed as the last argument of `FittedModelPayload::new` by
/// `assemble_transformation_normal_payload`.
const FAMILY_TRANSFORMATION_NORMAL: &str = "transformation-normal";
/// Converts an anchored-deviation [`DeviationRuntime`] into its serializable
/// [`SavedCompiledFlexBlock`] form.
///
/// The breakpoints, basis dimension and the four cubic span coefficient matrices
/// (`span_c0` .. `span_c3`, stored row by row as nested `Vec`s) are always saved,
/// tagged with [`ANCHORED_DEVIATION_KERNEL`]. If the runtime reports an installed
/// flex block, its anchor-correction matrix and anchor-component tags are saved
/// too. Otherwise `anchor_correction` is `None` and `anchor_components` is empty.
pub fn serialize_anchored_deviation_runtime(runtime: &DeviationRuntime) -> SavedCompiledFlexBlock {
    let (anchor_correction, anchor_components) = match runtime.installed_flex_block() {
        None => (None, Vec::new()),
        Some(installed) => {
            let correction: Vec<Vec<f64>> = installed
                .anchor_correction
                .rows()
                .into_iter()
                .map(|r| r.to_vec())
                .collect();
            let components: Vec<SavedAnchorComponent> = installed
                .anchor_components
                .iter()
                .map(|tag| {
                    // Mirror each runtime tag into its saved counterpart field-for-field.
                    let kind = match tag {
                        AnchorComponentTag::Parametric { block, ncols } => {
                            SavedAnchorKind::Parametric {
                                block: *block,
                                ncols: *ncols,
                            }
                        }
                        AnchorComponentTag::FlexEvaluation { ncols } => {
                            SavedAnchorKind::FlexEvaluation { ncols: *ncols }
                        }
                    };
                    SavedAnchorComponent { kind }
                })
                .collect();
            (Some(correction), components)
        }
    };

    SavedCompiledFlexBlock {
        kernel: ANCHORED_DEVIATION_KERNEL.to_string(),
        breakpoints: runtime.breakpoints().to_vec(),
        basis_dim: runtime.basis_dim(),
        span_c0: runtime.span_c0().rows().into_iter().map(|r| r.to_vec()).collect(),
        span_c1: runtime.span_c1().rows().into_iter().map(|r| r.to_vec()).collect(),
        span_c2: runtime.span_c2().rows().into_iter().map(|r| r.to_vec()).collect(),
        span_c3: runtime.span_c3().rows().into_iter().map(|r| r.to_vec()).collect(),
        anchor_correction,
        anchor_components,
    }
}
/// Training-data provenance that every `assemble_*_payload` function in this
/// module writes into the payload it builds, via `apply_to`, as its last step.
pub struct SavedModelSourceMetadata {
/// Training column headers; handed to the payload in both branches of `apply_to`.
pub training_headers: Vec<String>,
/// Optional per-feature `(f64, f64)` ranges (presumably `(min, max)` — confirm).
/// When present they are stored together with the headers through
/// `FittedModelPayload::set_training_feature_metadata`.
pub training_feature_ranges: Option<Vec<(f64, f64)>>,
/// Offset column name, copied to `payload.offset_column`.
pub offset_column: Option<String>,
/// Noise-offset column name, copied to `payload.noise_offset_column`.
pub noise_offset_column: Option<String>,
}
impl SavedModelSourceMetadata {
fn apply_to(self, payload: &mut FittedModelPayload) {
match self.training_feature_ranges {
Some(ranges) => payload.set_training_feature_metadata(self.training_headers, ranges),
None => payload.training_headers = Some(self.training_headers),
}
payload.offset_column = self.offset_column;
payload.noise_offset_column = self.noise_offset_column;
}
}
/// Everything `assemble_bernoulli_marginal_slope_payload` needs to build a
/// Bernoulli marginal-slope payload.
pub struct BernoulliMarginalSlopeInputs<'a> {
/// Marginal (main) formula; passed to `FittedModelPayload::new`.
pub formula: String,
/// Stored as `payload.data_schema`.
pub data_schema: DataSchema,
/// Log-slope formula. Stored as `formula_logslope` and as the single entry of
/// `formula_logslopes`.
pub logslope_formula: String,
/// Latent `z` column name. Stored as `z_column` and as the single entry of `z_columns`.
pub z_column: String,
/// Resolved term spec of the marginal formula; stored as `resolved_termspec`.
pub resolved_marginalspec: TermCollectionSpec,
/// Resolved term spec of the log-slope formula. Stored as
/// `resolved_termspec_logslope` and as the single entry of `resolved_termspec_logslopes`.
pub resolved_logslopespec: TermCollectionSpec,
/// Fit result whose first coefficient block may be widened by influence-absorber
/// columns. It is narrowed to `p_marginal` coefficients before being stored.
pub fit_result: UnifiedFitResult,
/// Number of marginal coefficients kept in the first coefficient block of `fit_result`.
pub p_marginal: usize,
/// Stored as `payload.marginal_baseline`.
pub baseline_marginal: f64,
/// Stored as `logslope_baseline` and as the single entry of `logslope_baselines`.
pub baseline_logslope: f64,
/// Stored as `payload.latent_z_normalization`.
pub latent_z_normalization: SavedLatentZNormalization,
/// Stored as `payload.latent_measure`.
pub latent_measure: LatentMeasureKind,
/// Optional latent-z rank calibration; copied to the payload as-is.
pub latent_z_rank_int_calibration: Option<LatentZRankIntCalibration>,
/// Optional score-warp runtime; serialized with `serialize_anchored_deviation_runtime`.
pub score_warp_runtime: Option<&'a DeviationRuntime>,
/// Optional link-deviation runtime; serialized with `serialize_anchored_deviation_runtime`.
pub link_dev_runtime: Option<&'a DeviationRuntime>,
/// Base inverse link. It selects the binomial likelihood spec and is stored as both
/// the family base link and `payload.link`.
pub base_link: InverseLink,
/// Frailty specification recorded in the fitted family.
pub frailty: crate::families::lognormal_kernel::FrailtySpec,
}
/// Removes influence-absorber rows and columns from an optional covariance matrix.
///
/// The matrix is assumed to be laid out over the *widened* coefficient vector. Its
/// first `widened_len` indices belong to the first coefficient block: the
/// `p_marginal` marginal coefficients followed by the absorber columns. Indices
/// `p_marginal..widened_len` are dropped; everything before and after them is kept
/// in order. `label` only names the matrix in the error message.
///
/// # Errors
///
/// Returns an error when the matrix is not square or has fewer than `widened_len`
/// rows. In that case the absorber slice cannot line up with the matrix, and
/// trimming it would either panic on indexing or silently discard the wrong entries.
fn drop_influence_absorber_block(
    cov: Option<Array2<f64>>,
    p_marginal: usize,
    widened_len: usize,
    label: &str,
) -> Result<Option<Array2<f64>>, String> {
    let Some(cov) = cov else {
        return Ok(None);
    };
    let (nrows, ncols) = cov.dim();
    if nrows != ncols || nrows < widened_len {
        return Err(format!(
            "marginal-slope {label} covariance has shape {nrows}x{ncols}; expected a square \
             matrix with at least {widened_len} rows to drop influence-absorber indices \
             {p_marginal}..{widened_len}"
        ));
    }
    let kept: Vec<usize> = (0..p_marginal).chain(widened_len..nrows).collect();
    let k = kept.len();
    Ok(Some(Array2::from_shape_fn((k, k), |(i, j)| {
        cov[[kept[i], kept[j]]]
    })))
}

/// Narrows a widened marginal-slope fit result back to `p_marginal` marginal
/// coefficients.
///
/// If the first coefficient block is longer than `p_marginal`, its trailing entries
/// (the influence-absorber columns) are removed from:
/// * that block's `beta`,
/// * the first block state's `beta`, and
/// * the matching rows and columns of both covariance matrices.
///
/// The widened flattened `beta` and the `geometry` are discarded, and the result is
/// rebuilt with `UnifiedFitResult::try_from_parts`. Presumably that function
/// re-derives the flattened coefficient vector — confirm there. A fit whose first
/// block already has at most `p_marginal` coefficients is returned unchanged.
///
/// # Errors
///
/// Returns an error if:
/// * the fit has no coefficient blocks,
/// * a covariance matrix does not cover the widened first block, or
/// * the narrowed parts fail `try_from_parts` validation.
fn truncate_marginal_slope_influence_absorber(
    fit_result: UnifiedFitResult,
    p_marginal: usize,
) -> Result<UnifiedFitResult, String> {
    let Some(block0) = fit_result.blocks.first() else {
        return Err("marginal-slope fit result has no coefficient blocks".to_string());
    };
    let widened_len = block0.beta.len();
    if widened_len <= p_marginal {
        return Ok(fit_result);
    }
    let UnifiedFitResult {
        mut blocks,
        log_lambdas,
        lambdas,
        likelihood_family,
        likelihood_scale,
        log_likelihood_normalization,
        log_likelihood,
        deviance,
        reml_score,
        stable_penalty_term,
        penalized_objective,
        outer_iterations,
        outer_converged,
        outer_gradient_norm,
        standard_deviation,
        covariance_conditional,
        covariance_corrected,
        inference,
        fitted_link,
        geometry: _,
        mut block_states,
        beta: _,
        pirls_status,
        max_abs_eta,
        constraint_kkt,
        artifacts,
        inner_cycles,
    } = fit_result;
    blocks[0].beta = blocks[0].beta.slice(ndarray::s![..p_marginal]).to_owned();
    if let Some(state0) = block_states.first_mut() {
        state0.beta = state0.beta.slice(ndarray::s![..p_marginal]).to_owned();
    }
    // Validate and trim both covariances before rebuilding. A layout mismatch is
    // then reported through this function's `Result` instead of an index panic.
    let covariance_conditional = drop_influence_absorber_block(
        covariance_conditional,
        p_marginal,
        widened_len,
        "conditional",
    )?;
    let covariance_corrected = drop_influence_absorber_block(
        covariance_corrected,
        p_marginal,
        widened_len,
        "corrected",
    )?;
    UnifiedFitResult::try_from_parts(crate::estimate::UnifiedFitResultParts {
        blocks,
        log_lambdas,
        lambdas,
        likelihood_family,
        likelihood_scale,
        log_likelihood_normalization,
        log_likelihood,
        deviance,
        reml_score,
        stable_penalty_term,
        penalized_objective,
        outer_iterations,
        outer_converged,
        outer_gradient_norm,
        standard_deviation,
        covariance_conditional,
        covariance_corrected,
        inference,
        fitted_link,
        geometry: None,
        block_states,
        pirls_status,
        max_abs_eta,
        constraint_kkt,
        artifacts,
        inner_cycles,
    })
    .map_err(|e| {
        format!("marginal-slope influence-absorber truncation produced an invalid fit result: {e}")
    })
}
/// Builds the saved payload for a Bernoulli marginal-slope fit.
///
/// Steps:
/// 1. Narrow the fit result to `p_marginal` marginal coefficients with
///    `truncate_marginal_slope_influence_absorber`.
/// 2. Derive the binomial likelihood from `base_link`.
/// 3. Copy the formulas, latent-z settings, baselines, resolved term specs and the
///    optional deviation runtimes into the payload. The log-slope entries are also
///    stored as one-element lists.
/// 4. Apply `source` metadata last.
///
/// # Errors
///
/// Returns an error if the truncation fails or `base_link` has no binomial
/// likelihood spec.
pub fn assemble_bernoulli_marginal_slope_payload(
    inputs: BernoulliMarginalSlopeInputs<'_>,
    source: SavedModelSourceMetadata,
) -> Result<FittedModelPayload, String> {
    // Narrow first so both stored copies of the fit see only the marginal block.
    let narrowed_fit =
        truncate_marginal_slope_influence_absorber(inputs.fit_result, inputs.p_marginal)?;
    let likelihood = match inverse_link_to_binomial_spec(&inputs.base_link) {
        Ok(spec) => spec,
        Err(err) => return Err(err.to_string()),
    };
    let family = FittedFamily::MarginalSlope {
        likelihood,
        base_link: Some(inputs.base_link.clone()),
        frailty: inputs.frailty,
    };
    let mut payload = FittedModelPayload::new(
        MODEL_PAYLOAD_VERSION,
        inputs.formula,
        ModelKind::MarginalSlope,
        family,
        FAMILY_BERNOULLI_MARGINAL_SLOPE.to_string(),
    );

    payload.unified = Some(narrowed_fit.clone());
    payload.fit_result = Some(narrowed_fit);
    payload.data_schema = Some(inputs.data_schema);

    // Single-slope model: scalar fields plus one-element list variants.
    payload.formula_logslope = Some(inputs.logslope_formula.clone());
    payload.z_column = Some(inputs.z_column.clone());
    payload.formula_logslopes = Some(vec![inputs.logslope_formula]);
    payload.z_columns = Some(vec![inputs.z_column]);

    payload.latent_z_normalization = Some(inputs.latent_z_normalization);
    payload.latent_measure = Some(inputs.latent_measure);
    payload.latent_z_rank_int_calibration = inputs.latent_z_rank_int_calibration;

    payload.marginal_baseline = Some(inputs.baseline_marginal);
    payload.logslope_baseline = Some(inputs.baseline_logslope);
    payload.logslope_baselines = Some(vec![inputs.baseline_logslope]);

    payload.link = Some(inputs.base_link);
    payload.resolved_termspec = Some(inputs.resolved_marginalspec);
    payload.resolved_termspec_logslopes = Some(vec![inputs.resolved_logslopespec.clone()]);
    payload.resolved_termspec_logslope = Some(inputs.resolved_logslopespec);

    payload.score_warp_runtime = inputs
        .score_warp_runtime
        .map(serialize_anchored_deviation_runtime);
    payload.link_deviation_runtime = inputs
        .link_dev_runtime
        .map(serialize_anchored_deviation_runtime);

    source.apply_to(&mut payload);
    Ok(payload)
}
/// Everything `assemble_transformation_normal_payload` needs to build a
/// transformation-normal payload.
pub struct TransformationNormalInputs<'a> {
/// Model formula; passed to `FittedModelPayload::new`.
pub formula: String,
/// Stored as `payload.data_schema`.
pub data_schema: DataSchema,
/// Resolved covariate term spec; stored as `payload.resolved_termspec`.
pub resolved_covariate_spec: TermCollectionSpec,
/// Fit result; stored as both `unified` and `fit_result`.
pub fit_result: UnifiedFitResult,
/// Fitted family. Its response knots, transform matrix, degree and median are
/// copied into the payload.
pub family: &'a TransformationNormalFamily,
/// Stored as `payload.transformation_score_calibration`.
pub score_calibration: TransformationScoreCalibration,
}
/// Builds the saved payload for a transformation-normal fit.
///
/// The family is recorded with a Gaussian/identity likelihood. The response
/// transformation is described by the fitted family's knots, transform matrix
/// (stored row by row), degree and median. `source` metadata is applied last.
pub fn assemble_transformation_normal_payload(
    inputs: TransformationNormalInputs<'_>,
    source: SavedModelSourceMetadata,
) -> FittedModelPayload {
    let family = inputs.family;
    let gaussian_identity = LikelihoodSpec::new(
        ResponseFamily::Gaussian,
        InverseLink::Standard(StandardLink::Identity),
    );
    let mut payload = FittedModelPayload::new(
        MODEL_PAYLOAD_VERSION,
        inputs.formula,
        ModelKind::TransformationNormal,
        FittedFamily::TransformationNormal {
            likelihood: gaussian_identity,
        },
        FAMILY_TRANSFORMATION_NORMAL.to_string(),
    );

    payload.unified = Some(inputs.fit_result.clone());
    payload.fit_result = Some(inputs.fit_result);
    payload.data_schema = Some(inputs.data_schema);
    payload.resolved_termspec = Some(inputs.resolved_covariate_spec);

    payload.transformation_response_knots = Some(family.response_knots().to_vec());
    let transform_rows: Vec<Vec<_>> = family
        .response_transform()
        .rows()
        .into_iter()
        .map(|r| r.to_vec())
        .collect();
    payload.transformation_response_transform = Some(transform_rows);
    payload.transformation_response_degree = Some(family.response_degree());
    payload.transformation_response_median = Some(family.response_median());
    payload.transformation_score_calibration = Some(inputs.score_calibration);

    source.apply_to(&mut payload);
    payload
}
/// Response-specific settings for `assemble_location_scale_payload`.
pub enum LocationScaleResponse<'a> {
/// Gaussian response, saved under the `"gaussian-location-scale"` tag.
Gaussian {
/// Stored as `payload.gaussian_response_scale`.
response_scale: f64,
/// Stored as `payload.link` (identity when `None`). Not recorded as the
/// family's base link.
base_link: Option<InverseLink>,
},
/// Binomial response, saved under the `"binomial-location-scale"` tag.
Binomial {
/// Inverse link. It selects the binomial likelihood and is stored as both the
/// family base link and `payload.link`.
link: InverseLink,
/// Noise-design transform whose projection, center, scale, non-intercept start
/// and ridge alpha are saved in the `noise_*` payload fields.
noise_transform: &'a ScaleDeviationTransform,
},
}
/// Optional link-wiggle settings saved by `assemble_location_scale_payload`.
pub struct LocationScaleWiggle {
/// Stored as `payload.linkwiggle_knots`.
pub knots: Vec<f64>,
/// Stored as `payload.linkwiggle_degree`.
pub degree: usize,
/// Stored as `payload.beta_link_wiggle`.
pub beta_link_wiggle: Vec<f64>,
}
/// Response-independent inputs for `assemble_location_scale_payload`.
pub struct LocationScaleInputs {
/// Location formula; passed to `FittedModelPayload::new`.
pub formula: String,
/// Stored as `payload.data_schema`.
pub data_schema: DataSchema,
/// Scale (noise) formula; stored as `payload.formula_noise`.
pub noise_formula: String,
/// Resolved location term spec; stored as `payload.resolved_termspec`.
pub resolved_termspec: TermCollectionSpec,
/// Resolved noise term spec; stored as `payload.resolved_termspec_noise`.
pub resolved_termspec_noise: TermCollectionSpec,
/// Fit result; stored as both `unified` and `fit_result`.
pub fit_result: UnifiedFitResult,
/// Optional noise coefficients; stored as `payload.beta_noise`.
pub beta_noise: Option<Vec<f64>>,
/// Optional link wiggle; its fields are copied only when present.
pub wiggle: Option<LocationScaleWiggle>,
}
pub fn assemble_location_scale_payload(
inputs: LocationScaleInputs,
response: LocationScaleResponse<'_>,
source: SavedModelSourceMetadata,
) -> Result<FittedModelPayload, String> {
let (family_tag, likelihood, base_link, link, response_scale, noise_transform) = match response
{
LocationScaleResponse::Gaussian {
response_scale,
base_link,
} => (
"gaussian-location-scale".to_string(),
LikelihoodSpec::gaussian_identity(),
None,
Some(base_link.unwrap_or(InverseLink::Standard(StandardLink::Identity))),
Some(response_scale),
None,
),
LocationScaleResponse::Binomial {
link,
noise_transform,
} => {
let likelihood = inverse_link_to_binomial_spec(&link).map_err(|e| {
format!("failed to resolve LikelihoodSpec for binomial location-scale link {link:?}: {e}")
})?;
(
"binomial-location-scale".to_string(),
likelihood,
Some(link.clone()),
Some(link),
None,
Some(noise_transform),
)
}
};
let mut payload = FittedModelPayload::new(
MODEL_PAYLOAD_VERSION,
inputs.formula,
ModelKind::LocationScale,
FittedFamily::LocationScale {
likelihood,
base_link,
},
family_tag,
);
payload.unified = Some(inputs.fit_result.clone());
payload.fit_result = Some(inputs.fit_result);
payload.data_schema = Some(inputs.data_schema);
payload.link = link;
payload.formula_noise = Some(inputs.noise_formula);
payload.beta_noise = inputs.beta_noise;
payload.gaussian_response_scale = response_scale;
if let Some(transform) = noise_transform {
payload.noise_projection = Some(
transform
.projection_coef
.rows()
.into_iter()
.map(|row| row.to_vec())
.collect(),
);
payload.noise_center = Some(transform.weighted_column_mean.to_vec());
payload.noise_scale = Some(transform.rescale.to_vec());
payload.noise_non_intercept_start = Some(transform.non_intercept_start);
payload.noise_projection_ridge_alpha = Some(transform.projection_ridge_alpha);
}
payload.resolved_termspec = Some(inputs.resolved_termspec);
payload.resolved_termspec_noise = Some(inputs.resolved_termspec_noise);
if let Some(wiggle) = inputs.wiggle {
payload.linkwiggle_knots = Some(wiggle.knots);
payload.linkwiggle_degree = Some(wiggle.degree);
payload.beta_link_wiggle = Some(wiggle.beta_link_wiggle);
}
source.apply_to(&mut payload);
Ok(payload)
}
/// Everything `assemble_survival_marginal_slope_payload` needs to build a survival
/// marginal-slope payload.
pub struct SurvivalMarginalSlopeInputs<'a> {
/// Marginal (main) formula; passed to `FittedModelPayload::new`.
pub formula: String,
/// Stored as `payload.data_schema`.
pub data_schema: DataSchema,
/// Fit result; stored as both `unified` and `fit_result` (not truncated here).
pub fit_result: UnifiedFitResult,
/// Frailty specification recorded in the fitted family.
pub frailty: crate::families::lognormal_kernel::FrailtySpec,
/// Optional entry-time column; stored as `payload.survival_entry`.
pub survival_entry: Option<String>,
/// Exit-time column; stored as `payload.survival_exit`.
pub survival_exit: String,
/// Event-indicator column; stored as `payload.survival_event`.
pub survival_event: String,
/// Stored as `payload.survivalspec`.
pub survivalspec: String,
/// Baseline target name and optional scale, shape, rate and Makeham parameters.
pub baseline_cfg: SurvivalBaselineConfig,
/// Applied through `FittedModelPayload::apply_survival_time_basis`.
pub time_basis: SavedSurvivalTimeBasis,
/// Stored as `payload.survivalridge_lambda`.
pub ridge_lambda: f64,
/// Survival likelihood label, recorded in both the family and the payload.
pub survival_likelihood_label: String,
/// Resolved marginal term spec; stored as `payload.resolved_termspec`.
pub resolved_marginalspec: TermCollectionSpec,
/// Resolved log-slope term spec. Stored as `resolved_termspec_logslope` and as the
/// single entry of `resolved_termspec_logslopes`.
pub resolved_logslopespec: TermCollectionSpec,
/// Log-slope formula. Stored as `formula_logslope` and as the single entry of
/// `formula_logslopes`.
pub logslope_formula: String,
/// Latent `z` column. Stored as `z_column` and as the single entry of `z_columns`.
pub z_column: String,
/// Stored as `payload.latent_z_normalization`.
pub latent_z_normalization: SavedLatentZNormalization,
/// Stored as `logslope_baseline` and as the single entry of `logslope_baselines`.
pub baseline_logslope: f64,
/// Optional score-warp runtime; serialized with `serialize_anchored_deviation_runtime`.
pub score_warp_runtime: Option<&'a DeviationRuntime>,
/// Optional link-deviation runtime; serialized with `serialize_anchored_deviation_runtime`.
pub link_dev_runtime: Option<&'a DeviationRuntime>,
/// Stored as `payload.influence_absorber_width`.
pub influence_absorber_width: Option<usize>,
}
pub fn assemble_survival_marginal_slope_payload(
inputs: SurvivalMarginalSlopeInputs<'_>,
source: SavedModelSourceMetadata,
) -> FittedModelPayload {
let mut payload = FittedModelPayload::new(
MODEL_PAYLOAD_VERSION,
inputs.formula,
ModelKind::Survival,
FittedFamily::Survival {
likelihood: LikelihoodSpec::new(
ResponseFamily::RoystonParmar,
InverseLink::Standard(StandardLink::Identity),
),
survival_likelihood: Some(inputs.survival_likelihood_label.clone()),
survival_distribution: Some(ResidualDistribution::Gaussian),
frailty: inputs.frailty,
},
ResponseFamily::RoystonParmar.name().to_string(),
);
payload.unified = Some(inputs.fit_result.clone());
payload.fit_result = Some(inputs.fit_result);
payload.data_schema = Some(inputs.data_schema);
payload.survival_entry = inputs.survival_entry;
payload.survival_exit = Some(inputs.survival_exit);
payload.survival_event = Some(inputs.survival_event);
payload.survivalspec = Some(inputs.survivalspec);
payload.survival_baseline_target =
Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
payload.survival_baseline_scale = inputs.baseline_cfg.scale;
payload.survival_baseline_shape = inputs.baseline_cfg.shape;
payload.survival_baseline_rate = inputs.baseline_cfg.rate;
payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
payload.apply_survival_time_basis(&inputs.time_basis);
payload.survivalridge_lambda = Some(inputs.ridge_lambda);
payload.survival_likelihood = Some(inputs.survival_likelihood_label);
payload.survival_distribution = Some(ResidualDistribution::Gaussian);
payload.link = Some(InverseLink::Standard(StandardLink::Probit));
payload.resolved_termspec = Some(inputs.resolved_marginalspec);
payload.resolved_termspec_logslopes = Some(vec![inputs.resolved_logslopespec.clone()]);
payload.resolved_termspec_logslope = Some(inputs.resolved_logslopespec);
payload.formula_logslope = Some(inputs.logslope_formula.clone());
payload.formula_logslopes = Some(vec![inputs.logslope_formula]);
payload.z_column = Some(inputs.z_column.clone());
payload.z_columns = Some(vec![inputs.z_column]);
payload.latent_z_normalization = Some(inputs.latent_z_normalization);
payload.latent_measure = Some(LatentMeasureKind::StandardNormal);
payload.logslope_baseline = Some(inputs.baseline_logslope);
payload.logslope_baselines = Some(vec![inputs.baseline_logslope]);
payload.score_warp_runtime = inputs
.score_warp_runtime
.map(serialize_anchored_deviation_runtime);
payload.link_deviation_runtime = inputs
.link_dev_runtime
.map(serialize_anchored_deviation_runtime);
payload.influence_absorber_width = inputs.influence_absorber_width;
source.apply_to(&mut payload);
payload
}
/// Coefficients of a baseline time-wiggle.
pub enum SurvivalTimewiggleBeta {
/// One coefficient vector; `assemble_survival_transformation_payload` stores it as
/// `payload.beta_baseline_timewiggle`.
Single(Vec<f64>),
/// One coefficient vector per entry (presumably per competing cause — confirm);
/// `assemble_survival_transformation_payload` stores it as
/// `payload.beta_baseline_timewiggle_by_cause`.
ByCause(Vec<Vec<f64>>),
}
/// Baseline time-wiggle settings saved into `baseline_timewiggle_*` payload fields.
pub struct SurvivalTimewiggle {
/// Stored as `payload.baseline_timewiggle_degree`.
pub degree: usize,
/// Stored as `payload.baseline_timewiggle_knots`.
pub knots: Vec<f64>,
/// Stored as `payload.baseline_timewiggle_penalty_orders`.
pub penalty_orders: Option<Vec<usize>>,
/// Stored as `payload.baseline_timewiggle_double_penalty`.
pub double_penalty: Option<bool>,
/// Fitted wiggle coefficients, either a single vector or one per cause.
pub beta: SurvivalTimewiggleBeta,
}
/// Everything `assemble_survival_transformation_payload` needs to build a survival
/// transformation payload.
pub struct SurvivalTransformationInputs {
/// Model formula; passed to `FittedModelPayload::new`.
pub formula: String,
/// Stored as `payload.data_schema`.
pub data_schema: DataSchema,
/// Fit result; stored as both `unified` and `fit_result`.
pub fit_result: UnifiedFitResult,
/// Optional entry-time column; stored as `payload.survival_entry`.
pub survival_entry: Option<String>,
/// Exit-time column; stored as `payload.survival_exit`.
pub survival_exit: String,
/// Event-indicator column; stored as `payload.survival_event`.
pub survival_event: String,
/// Stored as `payload.survivalspec`.
pub survivalspec: String,
/// When `Some(n)`, records `n` causes named `cause_1` .. `cause_n`.
pub cause_count: Option<usize>,
/// Baseline target name and optional scale, shape, rate and Makeham parameters.
pub baseline_cfg: SurvivalBaselineConfig,
/// Applied through `FittedModelPayload::apply_survival_time_basis`.
pub time_basis: SavedSurvivalTimeBasis,
/// Stored as `payload.survivalridge_lambda`.
pub ridge_lambda: f64,
/// Survival likelihood label, recorded in both the family and the payload.
pub survival_likelihood_label: String,
/// Stored as `payload.resolved_termspec`.
pub resolved_termspec: TermCollectionSpec,
/// Stored as `payload.survival_beta_time`.
pub survival_beta_time: Option<Vec<f64>>,
/// Optional baseline time-wiggle; its fields are copied only when present.
pub timewiggle: Option<SurvivalTimewiggle>,
}
pub fn assemble_survival_transformation_payload(
inputs: SurvivalTransformationInputs,
source: SavedModelSourceMetadata,
) -> FittedModelPayload {
let mut payload = FittedModelPayload::new(
MODEL_PAYLOAD_VERSION,
inputs.formula,
ModelKind::Survival,
FittedFamily::Survival {
likelihood: LikelihoodSpec::new(
ResponseFamily::RoystonParmar,
InverseLink::Standard(StandardLink::Identity),
),
survival_likelihood: Some(inputs.survival_likelihood_label.clone()),
survival_distribution: None,
frailty: crate::families::lognormal_kernel::FrailtySpec::None,
},
ResponseFamily::RoystonParmar.name().to_string(),
);
payload.unified = Some(inputs.fit_result.clone());
payload.fit_result = Some(inputs.fit_result);
payload.data_schema = Some(inputs.data_schema);
payload.survival_entry = inputs.survival_entry;
payload.survival_exit = Some(inputs.survival_exit);
payload.survival_event = Some(inputs.survival_event);
payload.survivalspec = Some(inputs.survivalspec);
if let Some(cause_count) = inputs.cause_count {
payload.survival_cause_count = Some(cause_count);
payload.survival_endpoint_names = Some(
(1..=cause_count)
.map(|idx| format!("cause_{idx}"))
.collect(),
);
}
payload.survival_baseline_target =
Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
payload.survival_baseline_scale = inputs.baseline_cfg.scale;
payload.survival_baseline_shape = inputs.baseline_cfg.shape;
payload.survival_baseline_rate = inputs.baseline_cfg.rate;
payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
payload.apply_survival_time_basis(&inputs.time_basis);
if let Some(timewiggle) = inputs.timewiggle {
payload.baseline_timewiggle_degree = Some(timewiggle.degree);
payload.baseline_timewiggle_knots = Some(timewiggle.knots);
payload.baseline_timewiggle_penalty_orders = timewiggle.penalty_orders;
payload.baseline_timewiggle_double_penalty = timewiggle.double_penalty;
match timewiggle.beta {
SurvivalTimewiggleBeta::Single(beta) => {
payload.beta_baseline_timewiggle = Some(beta);
}
SurvivalTimewiggleBeta::ByCause(by_cause) => {
payload.beta_baseline_timewiggle_by_cause = Some(by_cause);
}
}
}
payload.survivalridge_lambda = Some(inputs.ridge_lambda);
payload.survival_likelihood = Some(inputs.survival_likelihood_label);
payload.survival_beta_time = inputs.survival_beta_time;
payload.resolved_termspec = Some(inputs.resolved_termspec);
source.apply_to(&mut payload);
payload
}
/// Everything `assemble_survival_location_scale_payload` needs to build a survival
/// location-scale payload.
pub struct SurvivalLocationScaleInputs<'a> {
/// Model formula; passed to `FittedModelPayload::new`.
pub formula: String,
/// Stored as `payload.data_schema`.
pub data_schema: DataSchema,
/// Fit result; stored as both `unified` and `fit_result`.
pub fit_result: UnifiedFitResult,
/// Fitted inverse link. Stored as `payload.link`; the residual distribution is
/// derived from it via `residual_distribution_from_inverse_link`.
pub fitted_inverse_link: InverseLink,
/// Stored as `payload.linkwiggle_degree`.
pub linkwiggle_degree: Option<usize>,
/// Stored as `payload.linkwiggle_knots`.
pub linkwiggle_knots: Option<Vec<f64>>,
/// Stored as `payload.beta_link_wiggle`.
pub beta_link_wiggle: Option<Vec<f64>>,
/// Optional baseline time-wiggle settings.
pub baseline_timewiggle: Option<SurvivalTimewiggle>,
/// Optional entry-time column; stored as `payload.survival_entry`.
pub survival_entry: Option<String>,
/// Exit-time column; stored as `payload.survival_exit`.
pub survival_exit: String,
/// Event-indicator column; stored as `payload.survival_event`.
pub survival_event: String,
/// Stored as `payload.survivalspec`.
pub survivalspec: String,
/// Baseline target name and optional scale, shape, rate and Makeham parameters.
pub baseline_cfg: SurvivalBaselineConfig,
/// Applied through `FittedModelPayload::apply_survival_time_basis`.
pub time_basis: SavedSurvivalTimeBasis,
/// Stored as `payload.survivalridge_lambda`.
pub ridge_lambda: f64,
/// Survival likelihood label, recorded in both the family and the payload.
pub survival_likelihood_label: String,
/// Stored as `payload.formula_noise`.
pub formula_noise: Option<String>,
/// Stored as `payload.survival_beta_time`.
pub survival_beta_time: Vec<f64>,
/// Stored as `payload.survival_beta_threshold`.
pub survival_beta_threshold: Vec<f64>,
/// Stored as `payload.survival_beta_log_sigma`.
pub survival_beta_log_sigma: Vec<f64>,
/// Noise-design transform saved into the `survival_noise_*` payload fields.
pub noise_transform: &'a ScaleDeviationTransform,
/// Resolved threshold term spec; stored as `payload.resolved_termspec`.
pub resolved_thresholdspec: TermCollectionSpec,
/// Resolved log-sigma term spec; stored as `payload.resolved_termspec_noise`.
pub resolved_log_sigmaspec: TermCollectionSpec,
}
/// Builds the saved payload for a survival location-scale fit.
///
/// The family is a Royston–Parmar survival model with no frailty. Its residual
/// distribution is derived from `fitted_inverse_link` via
/// `residual_distribution_from_inverse_link`; the same distribution is also
/// stored in `payload.survival_distribution`, and the link itself in `payload.link`.
///
/// The following are copied from `inputs`:
/// * link-wiggle and baseline time-wiggle settings,
/// * survival columns, baseline configuration and time basis,
/// * time, threshold and log-sigma coefficients,
/// * the noise-projection transform (row by row),
/// * the resolved term specs.
///
/// `source` metadata is applied last.
///
/// Baseline time-wiggle coefficients are routed exhaustively, in the same way as
/// `assemble_survival_transformation_payload`:
/// * `Single` goes to `beta_baseline_timewiggle`.
/// * `ByCause` goes to `beta_baseline_timewiggle_by_cause`.
///
/// Saved wiggle knots and degree are therefore never left without coefficients.
pub fn assemble_survival_location_scale_payload(
    inputs: SurvivalLocationScaleInputs<'_>,
    source: SavedModelSourceMetadata,
) -> FittedModelPayload {
    let survival_distribution =
        residual_distribution_from_inverse_link(&inputs.fitted_inverse_link);
    let mut payload = FittedModelPayload::new(
        MODEL_PAYLOAD_VERSION,
        inputs.formula,
        ModelKind::Survival,
        FittedFamily::Survival {
            likelihood: LikelihoodSpec::new(
                ResponseFamily::RoystonParmar,
                InverseLink::Standard(StandardLink::Identity),
            ),
            survival_likelihood: Some(inputs.survival_likelihood_label.clone()),
            survival_distribution,
            frailty: crate::families::lognormal_kernel::FrailtySpec::None,
        },
        ResponseFamily::RoystonParmar.name().to_string(),
    );
    payload.unified = Some(inputs.fit_result.clone());
    payload.fit_result = Some(inputs.fit_result);
    payload.data_schema = Some(inputs.data_schema);
    payload.link = Some(inputs.fitted_inverse_link);
    payload.linkwiggle_degree = inputs.linkwiggle_degree;
    payload.linkwiggle_knots = inputs.linkwiggle_knots;
    payload.beta_link_wiggle = inputs.beta_link_wiggle;
    if let Some(timewiggle) = inputs.baseline_timewiggle {
        payload.baseline_timewiggle_degree = Some(timewiggle.degree);
        payload.baseline_timewiggle_knots = Some(timewiggle.knots);
        payload.baseline_timewiggle_penalty_orders = timewiggle.penalty_orders;
        payload.baseline_timewiggle_double_penalty = timewiggle.double_penalty;
        // Exhaustive match: per-cause coefficients must not be dropped while the
        // wiggle's knots/degree above are kept.
        match timewiggle.beta {
            SurvivalTimewiggleBeta::Single(beta) => {
                payload.beta_baseline_timewiggle = Some(beta);
            }
            SurvivalTimewiggleBeta::ByCause(by_cause) => {
                payload.beta_baseline_timewiggle_by_cause = Some(by_cause);
            }
        }
    }
    payload.survival_entry = inputs.survival_entry;
    payload.survival_exit = Some(inputs.survival_exit);
    payload.survival_event = Some(inputs.survival_event);
    payload.survivalspec = Some(inputs.survivalspec);
    payload.survival_baseline_target =
        Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
    payload.survival_baseline_scale = inputs.baseline_cfg.scale;
    payload.survival_baseline_shape = inputs.baseline_cfg.shape;
    payload.survival_baseline_rate = inputs.baseline_cfg.rate;
    payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
    payload.apply_survival_time_basis(&inputs.time_basis);
    payload.survivalridge_lambda = Some(inputs.ridge_lambda);
    payload.survival_likelihood = Some(inputs.survival_likelihood_label);
    payload.formula_noise = inputs.formula_noise;
    payload.survival_beta_time = Some(inputs.survival_beta_time);
    payload.survival_beta_threshold = Some(inputs.survival_beta_threshold);
    payload.survival_beta_log_sigma = Some(inputs.survival_beta_log_sigma);
    payload.survival_noise_projection = Some(
        inputs
            .noise_transform
            .projection_coef
            .rows()
            .into_iter()
            .map(|row| row.to_vec())
            .collect(),
    );
    payload.survival_noise_center = Some(inputs.noise_transform.weighted_column_mean.to_vec());
    payload.survival_noise_scale = Some(inputs.noise_transform.rescale.to_vec());
    payload.survival_noise_non_intercept_start = Some(inputs.noise_transform.non_intercept_start);
    payload.survival_noise_projection_ridge_alpha =
        Some(inputs.noise_transform.projection_ridge_alpha);
    payload.survival_distribution = survival_distribution;
    payload.resolved_termspec = Some(inputs.resolved_thresholdspec);
    payload.resolved_termspec_noise = Some(inputs.resolved_log_sigmaspec);
    source.apply_to(&mut payload);
    payload
}
/// Everything `assemble_latent_window_payload` needs to build a latent-window
/// survival payload.
pub struct LatentWindowInputs {
/// Model formula; passed to `FittedModelPayload::new`.
pub formula: String,
/// Stored as `payload.data_schema`.
pub data_schema: DataSchema,
/// Fit result; stored as both `unified` and `fit_result`.
pub fit_result: UnifiedFitResult,
/// Fitted family, passed through to `FittedModelPayload::new` unchanged.
pub family: FittedFamily,
/// Label passed as the last argument of `FittedModelPayload::new`.
pub model_class_label: String,
/// Stored as `payload.survival_likelihood`.
pub likelihood_label: String,
/// Optional entry-time column; stored as `payload.survival_entry`.
pub survival_entry: Option<String>,
/// Exit-time column; stored as `payload.survival_exit`.
pub survival_exit: String,
/// Event-indicator column; stored as `payload.survival_event`.
pub survival_event: String,
/// Baseline target name and optional scale, shape, rate and Makeham parameters.
pub baseline_cfg: SurvivalBaselineConfig,
/// Applied through `FittedModelPayload::apply_survival_time_basis`.
pub time_basis: SavedSurvivalTimeBasis,
/// Stored as `payload.survivalridge_lambda`.
pub ridge_lambda: f64,
/// Stored as `payload.survival_beta_time`.
pub beta_time: Vec<f64>,
/// Stored as `payload.resolved_termspec`.
pub resolved_termspec: TermCollectionSpec,
}
pub fn assemble_latent_window_payload(
inputs: LatentWindowInputs,
source: SavedModelSourceMetadata,
) -> FittedModelPayload {
let mut payload = FittedModelPayload::new(
MODEL_PAYLOAD_VERSION,
inputs.formula,
ModelKind::Survival,
inputs.family,
inputs.model_class_label,
);
payload.unified = Some(inputs.fit_result.clone());
payload.fit_result = Some(inputs.fit_result);
payload.data_schema = Some(inputs.data_schema);
payload.survival_entry = inputs.survival_entry;
payload.survival_exit = Some(inputs.survival_exit);
payload.survival_event = Some(inputs.survival_event);
payload.survivalspec = Some("net".to_string());
payload.survival_baseline_target =
Some(survival_baseline_targetname(inputs.baseline_cfg.target).to_string());
payload.survival_baseline_scale = inputs.baseline_cfg.scale;
payload.survival_baseline_shape = inputs.baseline_cfg.shape;
payload.survival_baseline_rate = inputs.baseline_cfg.rate;
payload.survival_baseline_makeham = inputs.baseline_cfg.makeham;
payload.apply_survival_time_basis(&inputs.time_basis);
payload.survival_likelihood = Some(inputs.likelihood_label);
payload.survival_beta_time = Some(inputs.beta_time);
payload.survivalridge_lambda = Some(inputs.ridge_lambda);
payload.resolved_termspec = Some(inputs.resolved_termspec);
source.apply_to(&mut payload);
payload
}