use super::*;
/// Refuses marginal-slope controls on a transformation-normal fit.
///
/// The configuration is rejected when its `family` names the Bernoulli/binary
/// marginal-slope family (compared case-insensitively, with `_` treated as
/// `-`) or when any of `logslope_formula`, `z_column`, or `ctn_stage1` is set.
///
/// # Errors
///
/// Returns [`WorkflowError::InvalidConfig`] if any such control is present.
pub(crate) fn reject_marginal_slope_controls_for_transformation_normal(
    config: &FitConfig,
) -> Result<(), WorkflowError> {
    let names_marginal_slope_family = match config.family.as_deref() {
        Some(family) => matches!(
            family.to_ascii_lowercase().replace('_', "-").as_str(),
            "bernoulli-marginal-slope" | "binary-marginal-slope"
        ),
        None => false,
    };
    let has_marginal_slope_control = names_marginal_slope_family
        || config.logslope_formula.is_some()
        || config.z_column.is_some()
        || config.ctn_stage1.is_some();
    if !has_marginal_slope_control {
        return Ok(());
    }
    Err(WorkflowError::InvalidConfig {
        reason: "transformation_normal cannot be combined with marginal-slope family controls"
            .to_string(),
    })
}
/// Rejects formula terms that are only valid in a survival formula.
///
/// `timewiggle(...)` is checked first, then `survmodel(...)`; the first one
/// present determines the error message.
///
/// # Errors
///
/// Returns [`WorkflowError::InvalidConfig`] when `parsed` carries a
/// `timewiggle` or a `survivalspec`.
pub(crate) fn reject_survival_only_terms_for_nonsurvival(
    parsed: &ParsedFormula,
) -> Result<(), WorkflowError> {
    let violation = if parsed.timewiggle.is_some() {
        Some(
            "timewiggle(...) is only supported in the main survival formula \
             (a formula with a Surv(...) response); it is meaningless for a \
             non-survival response and would otherwise be silently ignored",
        )
    } else if parsed.survivalspec.is_some() {
        Some(
            "survmodel(...) is only supported in the main survival formula \
             (a formula with a Surv(...) response); it is meaningless for a \
             non-survival response and would otherwise be silently ignored",
        )
    } else {
        None
    };
    match violation {
        Some(reason) => Err(WorkflowError::InvalidConfig {
            reason: reason.to_string(),
        }),
        None => Ok(()),
    }
}
/// Rejects an explicit `linkwiggle(...)` term when the resolved family is not
/// binomial.
///
/// # Errors
///
/// Returns [`WorkflowError::InvalidConfig`] when `parsed.linkwiggle` is set
/// and `family.is_binomial()` is false.
fn reject_explicit_linkwiggle_for_nonbinomial(
    parsed: &ParsedFormula,
    family: &LikelihoodSpec,
) -> Result<(), WorkflowError> {
    // Nothing to reject without a linkwiggle term, or when the family is binomial.
    if parsed.linkwiggle.is_none() || family.is_binomial() {
        return Ok(());
    }
    Err(WorkflowError::InvalidConfig {
        reason: "linkwiggle(...) corrects the link function of a binomial mean model \
                 and is only supported for a binomial response; it is meaningless for \
                 the resolved non-binomial family and would otherwise be silently ignored"
            .to_string(),
    })
}
/// Reports whether every entry of `y` is within `1e-12` of `0.0` or `1.0`.
///
/// An empty response is never considered binary. `NaN` entries fail both
/// comparisons and therefore make the result `false`.
pub fn is_binary_response(y: ArrayView1<'_, f64>) -> bool {
    const TOLERANCE: f64 = 1e-12;
    let near = |value: f64, target: f64| (value - target).abs() < TOLERANCE;
    !y.is_empty() && y.iter().all(|&value| near(value, 0.0) || near(value, 1.0))
}
/// Checks that the dataset has enough rows for the requested smooth terms.
///
/// The row budget is the sum of `min_sample_rows()` over every smooth term
/// plus two extra rows (described in the error text as "intercept +
/// smoothing-parameter dof"). A spec without smooth terms always passes.
///
/// # Errors
///
/// Returns [`WorkflowError::InvalidConfig`] with a per-term breakdown when
/// `n_rows` is below the budget.
fn check_smooth_capacity(
    spec: &crate::terms::smooth::TermCollectionSpec,
    n_rows: usize,
    response_name: &str,
) -> Result<(), WorkflowError> {
    let per_term: Vec<(String, usize)> = spec
        .smooth_terms
        .iter()
        .map(|term| {
            let need = term.basis.min_sample_rows();
            (term.name.clone(), need)
        })
        .collect();
    // Start from the two baseline rows; saturate so huge bases cannot overflow.
    let required = per_term
        .iter()
        .fold(2usize, |total, (_, need)| total.saturating_add(*need));
    if per_term.is_empty() || n_rows >= required {
        return Ok(());
    }
    let breakdown = per_term
        .iter()
        .map(|(name, k)| format!("{name}≥{k}"))
        .collect::<Vec<_>>()
        .join(", ");
    Err(WorkflowError::InvalidConfig {
        reason: format!(
            "not enough observations to fit the requested formula: dataset has n={n_rows} \
             rows but the smooth terms on response '{response_name}' need at least \
             {required} rows total ({breakdown}, plus intercept + smoothing-parameter dof) \
             before REML estimation is well-posed. \
             Fix: add more training rows, replace `s(x)` with a linear term, or pass a \
             smaller basis via `s(x, k=3)`."
        ),
    })
}
/// Classifies the response column `y_col` of `data`.
///
/// A column with no recorded kind tag, or a continuous tag, is reported as
/// numeric. Categorical columns carry the level list from the schema entry at
/// the same index, or the default (empty) list when the schema has no entry.
pub(crate) fn response_column_kind(data: &Dataset, y_col: usize) -> ResponseColumnKind {
    let Some(tag) = data.column_kinds.get(y_col) else {
        return ResponseColumnKind::Numeric;
    };
    match tag {
        ColumnKindTag::Continuous => ResponseColumnKind::Numeric,
        ColumnKindTag::Binary => ResponseColumnKind::Binary,
        ColumnKindTag::Categorical => {
            let levels = match data.schema.columns.get(y_col) {
                Some(schema_column) => schema_column.levels.clone(),
                None => Default::default(),
            };
            ResponseColumnKind::Categorical { levels }
        }
    }
}
/// Reports whether `link` may be paired with the response family `response`.
///
/// Gaussian accepts only the identity link; Poisson, Gamma, Tweedie and
/// negative-binomial accept only log; Beta accepts only logit; Binomial
/// accepts logit, probit, cloglog, SAS and beta-logistic; Royston-Parmar
/// accepts no link at all.
fn link_legal_for_family(response: &ResponseFamily, link: LinkFunction) -> bool {
    let is_identity = matches!(link, LinkFunction::Identity);
    let is_log = matches!(link, LinkFunction::Log);
    let is_logit = matches!(link, LinkFunction::Logit);
    let is_binomial_link = matches!(
        link,
        LinkFunction::Logit
            | LinkFunction::Probit
            | LinkFunction::CLogLog
            | LinkFunction::Sas
            | LinkFunction::BetaLogistic
    );
    // Kept exhaustive over the family so a new variant forces a decision here.
    match response {
        ResponseFamily::RoystonParmar => false,
        ResponseFamily::Gaussian => is_identity,
        ResponseFamily::Beta { .. } => is_logit,
        ResponseFamily::Binomial => is_binomial_link,
        ResponseFamily::Poisson
        | ResponseFamily::Gamma
        | ResponseFamily::Tweedie { .. }
        | ResponseFamily::NegativeBinomial { .. } => is_log,
    }
}
/// Resolves the likelihood (response family plus inverse link) for a scalar fit.
///
/// Resolution proceeds in this order:
///
/// 1. `negative_binomial_theta` (defaulting to `1.0`) must be finite and
///    strictly positive, whatever family is requested.
/// 2. An explicit `family` name is canonicalised (ASCII lower-case, `_`
///    replaced by `-`) and mapped to a spec together with a flag saying
///    whether that name pins its link (e.g. `binomial` does not,
///    `binomial-logit` does). Vector-response names and unknown names are
///    rejected. With no family, a supplied theta is rejected.
/// 3. When `link_choice` is given, a spec is derived from the link alone
///    (mixture components yield a binomial mixture link; the log link yields
///    Poisson when every `y` is a finite non-negative integer within `1e-9`,
///    otherwise Gamma). If a family was also given, the link must be legal for
///    it and must not conflict with a pinned link; the result combines the
///    explicit response family with the requested link.
/// 4. Otherwise the explicit family's spec is returned as is.
/// 5. Otherwise the family is inferred from `y` / `y_kind`; binomial gets a
///    logit link, Poisson a log link and anything else the identity link.
///
/// # Errors
///
/// Returns a `String` describing the invalid theta, unknown or vector-response
/// family, illegal family/link combination, failed link-state construction, or
/// a refused family inference (the latter mentions `response_name`).
pub fn resolve_family(
family: Option<&str>,
negative_binomial_theta: Option<f64>,
link_choice: Option<&LinkChoice>,
y: ArrayView1<'_, f64>,
y_kind: ResponseColumnKind,
response_name: &str,
) -> Result<LikelihoodSpec, String> {
// Validate theta up front; the default of 1.0 is only used by the
// negative-binomial arms below.
let nb_theta = negative_binomial_theta.unwrap_or(1.0);
if !nb_theta.is_finite() || nb_theta <= 0.0 {
return Err(format!(
"negative-binomial theta must be finite and > 0; got {nb_theta}"
));
}
// `explicit` pairs the spec implied by the family name with a flag that is
// true when the name itself fixes the link.
// NOTE(review): a supplied theta is silently ignored when an explicit
// non-negative-binomial family is named, although the `None` arm below
// rejects theta without a family — confirm this asymmetry is intended.
let explicit: Option<(LikelihoodSpec, bool)> = match family {
Some(name) => {
let canonical = name.to_ascii_lowercase().replace('_', "-");
let resolved = match canonical.as_str() {
"gaussian" => (
LikelihoodSpec::new(
ResponseFamily::Gaussian,
InverseLink::Standard(StandardLink::Identity),
),
false,
),
"binomial" => (
LikelihoodSpec::new(
ResponseFamily::Binomial,
InverseLink::Standard(StandardLink::Logit),
),
false,
),
"binomial-logit" => (
LikelihoodSpec::new(
ResponseFamily::Binomial,
InverseLink::Standard(StandardLink::Logit),
),
true,
),
"binomial-probit" => (
LikelihoodSpec::new(
ResponseFamily::Binomial,
InverseLink::Standard(StandardLink::Probit),
),
true,
),
"binomial-cloglog" => (
LikelihoodSpec::new(
ResponseFamily::Binomial,
InverseLink::Standard(StandardLink::CLogLog),
),
true,
),
"latent-cloglog-binomial" => (
LikelihoodSpec::new(
ResponseFamily::Binomial,
InverseLink::LatentCLogLog(
LatentCLogLogState::new(1.0)
.map_err(|err| format!("latent cloglog default state: {err}"))?,
),
),
true,
),
"poisson" => (
LikelihoodSpec::new(
ResponseFamily::Poisson,
InverseLink::Standard(StandardLink::Log),
),
false,
),
// theta counts as fixed exactly when the caller supplied it.
"nb" | "negbin" | "negative-binomial" => (
LikelihoodSpec::new(
ResponseFamily::NegativeBinomial {
theta: nb_theta,
theta_fixed: negative_binomial_theta.is_some(),
},
InverseLink::Standard(StandardLink::Log),
),
false,
),
"negative-binomial-log" => (
LikelihoodSpec::new(
ResponseFamily::NegativeBinomial {
theta: nb_theta,
theta_fixed: negative_binomial_theta.is_some(),
},
InverseLink::Standard(StandardLink::Log),
),
true,
),
"beta" | "beta-regression" => (
LikelihoodSpec::new(
ResponseFamily::Beta { phi: 1.0 },
InverseLink::Standard(StandardLink::Logit),
),
false,
),
"beta-logit" | "beta-regression-logit" => (
LikelihoodSpec::new(
ResponseFamily::Beta { phi: 1.0 },
InverseLink::Standard(StandardLink::Logit),
),
true,
),
"gamma" => (
LikelihoodSpec::new(
ResponseFamily::Gamma,
InverseLink::Standard(StandardLink::Log),
),
false,
),
"royston-parmar" => (LikelihoodSpec::royston_parmar(), true),
// transformation-normal resolves to a Gaussian/identity spec with a pinned link.
"transformation-normal" => (
LikelihoodSpec::new(
ResponseFamily::Gaussian,
InverseLink::Standard(StandardLink::Identity),
),
true,
),
"tweedie" | "tw" => (
LikelihoodSpec::new(
ResponseFamily::Tweedie { p: 1.5 },
InverseLink::Standard(StandardLink::Log),
),
false,
),
"tweedie-log" => (
LikelihoodSpec::new(
ResponseFamily::Tweedie { p: 1.5 },
InverseLink::Standard(StandardLink::Log),
),
true,
),
// Vector-response families are redirected to their dedicated entry point.
"multinomial" | "multinomial-logit" | "categorical" | "categorical-logit"
| "softmax" => {
return Err(WorkflowError::InvalidConfig {
reason: format!(
"family '{name}' is a vector-response family; use \
the dedicated multinomial entry point \
(`crate::families::multinomial::fit_penalized_multinomial` \
in Rust, or `gamfit.fit_multinomial(...)` in Python) \
rather than the scalar `fit(family=...)` path"
),
}
.into());
}
_ => {
return Err(WorkflowError::InvalidConfig {
reason: format!("unknown family '{name}'"),
}
.into());
}
};
Some(resolved)
}
None => {
// theta without a family is rejected rather than ignored.
if negative_binomial_theta.is_some() {
return Err(WorkflowError::InvalidConfig {
reason: "negative_binomial_theta requires family='negative-binomial'"
.to_string(),
}
.into());
}
None
}
};
if let Some(choice) = link_choice {
// Spec implied by the link alone; its response family is only used when
// no explicit family was given.
let from_link: LikelihoodSpec = if let Some(components) = choice.mixture_components.as_ref()
{
// A mixture of n components gets n - 1 zero-initialised rho entries.
let n = components.len();
let free = n.saturating_sub(1);
let mix_spec = MixtureLinkSpec {
components: components.clone(),
initial_rho: Array1::<f64>::zeros(free),
};
let state = state_fromspec(&mix_spec)
.map_err(|err| format!("mixture link initial state: {err}"))?;
LikelihoodSpec::new(ResponseFamily::Binomial, InverseLink::Mixture(state))
} else {
match choice.link {
LinkFunction::Identity => LikelihoodSpec::new(
ResponseFamily::Gaussian,
InverseLink::Standard(StandardLink::Identity),
),
LinkFunction::Log => {
// Count-like responses (finite, non-negative, integer within
// 1e-9) map to Poisson; everything else maps to Gamma. An empty
// `y` satisfies `all` vacuously and therefore maps to Poisson.
if y.iter()
.all(|&yi| yi.is_finite() && yi >= 0.0 && (yi - yi.round()).abs() <= 1e-9)
{
LikelihoodSpec::new(
ResponseFamily::Poisson,
InverseLink::Standard(StandardLink::Log),
)
} else {
LikelihoodSpec::new(
ResponseFamily::Gamma,
InverseLink::Standard(StandardLink::Log),
)
}
}
LinkFunction::Logit => LikelihoodSpec::new(
ResponseFamily::Binomial,
InverseLink::Standard(StandardLink::Logit),
),
LinkFunction::Probit => LikelihoodSpec::new(
ResponseFamily::Binomial,
InverseLink::Standard(StandardLink::Probit),
),
LinkFunction::CLogLog => LikelihoodSpec::new(
ResponseFamily::Binomial,
InverseLink::Standard(StandardLink::CLogLog),
),
LinkFunction::Sas => {
let state = state_from_sasspec(SasLinkSpec {
initial_epsilon: 0.0,
initial_log_delta: 0.0,
})
.map_err(|err| format!("SAS link initial state: {err}"))?;
LikelihoodSpec::new(ResponseFamily::Binomial, InverseLink::Sas(state))
}
LinkFunction::BetaLogistic => {
let state = state_from_beta_logisticspec(SasLinkSpec {
initial_epsilon: 0.0,
initial_log_delta: 0.0,
})
.map_err(|err| format!("Beta-Logistic link initial state: {err}"))?;
LikelihoodSpec::new(ResponseFamily::Binomial, InverseLink::BetaLogistic(state))
}
}
};
if let Some((explicit_spec, link_pinned)) = explicit.as_ref() {
// Both family and link were given: check that they are compatible.
// A mixture link is only legal for the binomial family.
let mixture_requested = choice.mixture_components.is_some();
let legal = if mixture_requested {
matches!(explicit_spec.response, ResponseFamily::Binomial)
} else {
link_legal_for_family(&explicit_spec.response, choice.link)
};
if !legal {
return Err(WorkflowError::InvalidConfig {
reason: format!(
"link '{}' is not supported for family '{}'",
choice.link.name(),
explicit_spec.response.name()
),
}
.into());
}
// A family name that pins its link must agree with the requested link.
if *link_pinned && explicit_spec.link.link_function() != from_link.link.link_function()
{
return Err(WorkflowError::InvalidConfig {
reason: format!(
"family '{}' pins link '{}', which conflicts with requested link '{}'",
explicit_spec.name(),
explicit_spec.link.link_function().name(),
choice.link.name(),
),
}
.into());
}
// Keep the explicit response family, but use the requested link.
return Ok(LikelihoodSpec::new(
explicit_spec.response.clone(),
from_link.link,
));
}
return Ok(from_link);
}
if let Some((spec, _)) = explicit {
return Ok(spec);
}
// Neither family nor link was given: infer the family from the response.
let response = ResponseFamily::infer_from_response(y, y_kind).map_err(|refusal| {
let err: String = WorkflowError::InvalidConfig {
reason: refusal.message_for(response_name),
}
.into();
err
})?;
// NOTE(review): every inferred family other than Binomial/Poisson gets the
// identity link — confirm `infer_from_response` cannot return a family for
// which `link_legal_for_family` would reject identity.
let link = match response {
ResponseFamily::Binomial => InverseLink::Standard(StandardLink::Logit),
ResponseFamily::Poisson => InverseLink::Standard(StandardLink::Log),
_ => InverseLink::Standard(StandardLink::Identity),
};
Ok(LikelihoodSpec::new(response, link))
}
/// Builds the term collection spec, then optionally enables dimension scaling
/// and applies JSON smooth overrides.
///
/// The steps run in that order: `build_termspec`, `enable_scale_dimensions`
/// (only when `scale_dimensions` is true), then `apply_smooth_overrides`
/// (only when `smooth_overrides` is provided).
///
/// # Errors
///
/// Propagates any error from `build_termspec`; an override failure is wrapped
/// as [`WorkflowError::InvalidConfig`].
pub(crate) fn build_termspec_with_geometry_and_overrides(
    terms: &[ParsedTerm],
    data: &Dataset,
    col_map: &HashMap<String, usize>,
    inference_notes: &mut Vec<String>,
    scale_dimensions: bool,
    policy: &crate::resource::ResourcePolicy,
    smooth_overrides: Option<&JsonValue>,
) -> Result<TermCollectionSpec, WorkflowError> {
    let mut term_spec = build_termspec(terms, data, col_map, inference_notes, policy)?;
    if scale_dimensions {
        enable_scale_dimensions(&mut term_spec);
    }
    let Some(overrides) = smooth_overrides else {
        return Ok(term_spec);
    };
    match crate::terms::smooth_overrides::apply_smooth_overrides(
        &mut term_spec,
        overrides,
        data,
        inference_notes,
    ) {
        Ok(_) => Ok(term_spec),
        Err(reason) => Err(WorkflowError::InvalidConfig { reason }),
    }
}
/// Builds the training column of a linear term as the row-wise product of its
/// effective feature columns.
///
/// # Errors
///
/// Returns [`WorkflowError::InvalidConfig`] when the term has no feature
/// columns and [`WorkflowError::SchemaMismatch`] when a feature column index
/// is outside the data matrix.
fn linear_term_training_column(
    data: &Dataset,
    term: &LinearTermSpec,
) -> Result<Array1<f64>, WorkflowError> {
    let cols = term.effective_feature_cols();
    if cols.is_empty() {
        return Err(WorkflowError::InvalidConfig {
            reason: format!(
                "linear term '{}' has no feature columns; cannot build its training column",
                term.name
            ),
        });
    }
    let n_rows = data.values.nrows();
    let n_cols = data.values.ncols();
    // Start from ones so the first multiplication copies the first column.
    let mut product = Array1::<f64>::ones(n_rows);
    for &col in &cols {
        if col >= n_cols {
            return Err(WorkflowError::SchemaMismatch {
                reason: format!(
                    "linear term '{}' feature column {} out of bounds for {} columns",
                    term.name, col, n_cols
                ),
            });
        }
        product
            .iter_mut()
            .zip(data.values.column(col).iter())
            .for_each(|(acc, &value)| *acc *= value);
    }
    Ok(product)
}
/// Removes from `column` its component along each vector of `basis`, one
/// vector at a time, each projection being taken against the running residual.
///
/// The name indicates `basis` is expected to be orthonormal; the vectors are
/// used as given, with no normalisation performed here.
fn residualize_against_orthonormal_basis(
    column: &Array1<f64>,
    basis: &[Array1<f64>],
) -> Array1<f64> {
    basis.iter().fold(column.clone(), |mut residual, direction| {
        let projection = residual.dot(direction);
        residual.scaled_add(-projection, direction);
        residual
    })
}
/// Euclidean norm of `column`: the square root of the sum of squared entries,
/// accumulated in element order.
fn l2_norm(column: &Array1<f64>) -> f64 {
    let sum_of_squares: f64 = column.iter().map(|&value| value * value).sum();
    sum_of_squares.sqrt()
}
/// Drops linear terms whose training column adds no new direction beyond the
/// implicit intercept and the linear terms that precede them.
///
/// Terms are visited in order. Each column is residualised against an
/// orthonormal basis seeded with the normalised all-ones column; a term whose
/// residual norm is at or below the running tolerance is treated as redundant
/// and removed, otherwise its normalised residual joins the basis and the term
/// is kept. The tolerance grows with the largest column norm seen so far.
/// When anything is removed, a note listing the dropped terms is appended to
/// `inference_notes` and `spec.linear_terms` is replaced by the kept terms.
///
/// # Errors
///
/// Returns [`WorkflowError::InvalidConfig`] when the data have zero rows, a
/// norm is invalid or non-finite, or a redundant term carries coefficient
/// bounds or an explicit double penalty. Errors from building a term's
/// training column are propagated.
pub(crate) fn prune_unidentified_linear_terms_for_marginal_slope(
    spec: &mut TermCollectionSpec,
    data: &Dataset,
    label: &str,
    inference_notes: &mut Vec<String>,
) -> Result<(), WorkflowError> {
    if spec.linear_terms.is_empty() {
        return Ok(());
    }
    let n = data.values.nrows();
    if n == 0 {
        return Err(WorkflowError::InvalidConfig {
            reason: format!("{label}: cannot rank-check scalar terms on zero rows"),
        });
    }

    // Seed the basis with the normalised implicit intercept.
    let ones = Array1::<f64>::ones(n);
    let intercept_norm = l2_norm(&ones);
    if intercept_norm == 0.0 || !intercept_norm.is_finite() {
        return Err(WorkflowError::InvalidConfig {
            reason: format!("{label}: implicit intercept has invalid norm {intercept_norm}"),
        });
    }
    let mut basis: Vec<Array1<f64>> = vec![ones.mapv(|v| v / intercept_norm)];

    let rank_alpha = crate::linalg::faer_ndarray::default_rrqr_rank_alpha();
    let mut scale = intercept_norm.max(1.0);
    let mut kept: Vec<LinearTermSpec> = Vec::with_capacity(spec.linear_terms.len());
    let mut dropped: Vec<String> = Vec::new();

    for term in &spec.linear_terms {
        let column = linear_term_training_column(data, term)?;
        let norm = l2_norm(&column);
        if !norm.is_finite() {
            return Err(WorkflowError::InvalidConfig {
                reason: format!("{label}: linear term '{}' has non-finite norm", term.name),
            });
        }
        scale = scale.max(norm.max(1.0));

        let residual = residualize_against_orthonormal_basis(&column, &basis);
        let residual_norm = l2_norm(&residual);
        let tol = rank_alpha * f64::EPSILON * ((n + basis.len() + 1).max(1) as f64) * scale;

        if residual_norm <= tol {
            // Redundant direction: refuse to drop terms the user constrained
            // or explicitly penalised, otherwise record the removal.
            if term.coefficient_min.is_some() || term.coefficient_max.is_some() {
                return Err(WorkflowError::InvalidConfig {
                    reason: format!(
                        "{label}: constrained linear term '{}' is redundant with the implicit \
                         intercept or earlier scalar terms; remove the constraint or the \
                         redundant term",
                        term.name
                    ),
                });
            }
            if term.double_penalty {
                return Err(WorkflowError::InvalidConfig {
                    reason: format!(
                        "{label}: explicitly penalized linear term '{}' is redundant with the \
                         implicit intercept or earlier scalar terms; remove the redundant term \
                         instead of relying on a ridge to identify a duplicate data direction",
                        term.name
                    ),
                });
            }
            dropped.push(format!(
                "{} (residual_norm={:.3e}, tol={:.3e})",
                term.name, residual_norm, tol
            ));
        } else {
            // The extra comparison keeps a NaN residual norm out of the basis
            // while still keeping the term itself.
            if residual_norm > tol {
                basis.push(residual.mapv(|v| v / residual_norm));
            }
            kept.push(term.clone());
        }
    }

    if dropped.is_empty() {
        return Ok(());
    }
    inference_notes.push(format!(
        "{label}: removed {} scalar term(s) that add no identifiable \
         direction beyond the implicit intercept and earlier scalar terms: {}",
        dropped.len(),
        dropped.join(", ")
    ));
    spec.linear_terms = kept;
    Ok(())
}
/// Returns default adaptive-regularization options with `enabled` set when
/// the configuration opts in, and `None` when the flag is unset or false.
fn standard_adaptive_regularization_options(
    config: &FitConfig,
) -> Option<AdaptiveRegularizationOptions> {
    if !config.adaptive_regularization.unwrap_or(false) {
        return None;
    }
    Some(AdaptiveRegularizationOptions {
        enabled: true,
        ..AdaptiveRegularizationOptions::default()
    })
}
/// Resolves the base inverse link for the survival marginal-slope model.
///
/// With no link spec the probit link is used. A supplied spec is parsed with
/// `parse_link_choice` and must resolve to a plain (non-mixture) probit link.
///
/// # Errors
///
/// Returns a `String` when the link cannot be parsed, parses to nothing, asks
/// for mixture components, or names any link other than probit.
fn resolve_survival_marginal_slope_base_link(
    linkspec: Option<&crate::inference::formula_dsl::LinkFormulaSpec>,
) -> Result<InverseLink, String> {
    let choice = match linkspec {
        None => return Ok(InverseLink::Standard(StandardLink::Probit)),
        Some(spec) => parse_link_choice(Some(&spec.link), false)?
            .ok_or_else(|| "invalid survival marginal-slope link".to_string())?,
    };
    if choice.mixture_components.is_some() {
        return Err(WorkflowError::InvalidConfig {
            reason: "survival marginal-slope currently supports only link(type=probit)".to_string(),
        }
        .into());
    }
    let other = choice.link;
    if matches!(other, LinkFunction::Probit) {
        Ok(InverseLink::Standard(StandardLink::Probit))
    } else {
        Err(WorkflowError::InvalidConfig {
            reason: format!(
                "survival marginal-slope currently supports only link(type=probit), got {other:?}"
            ),
        }
        .into())
    }
}
/// Offsets, design matrices and penalties for the time block of a survival
/// fit, as assembled by `prepare_survival_time_stack`.
pub struct PreparedSurvivalTimeStack {
/// Entry-time linear-predictor offset, after being passed through the
/// derivative-guard offset step.
pub eta_offset_entry: Array1<f64>,
/// Exit-time linear-predictor offset, after being passed through the
/// derivative-guard offset step.
pub eta_offset_exit: Array1<f64>,
/// Exit-time derivative offset, after being passed through the
/// derivative-guard offset step.
pub derivative_offset_exit: Array1<f64>,
/// Taken from the latent baseline offsets when a latent loading is supplied;
/// all zeros otherwise.
pub unloaded_mass_entry: Array1<f64>,
/// Taken from the latent baseline offsets when a latent loading is supplied;
/// all zeros otherwise.
pub unloaded_mass_exit: Array1<f64>,
/// Taken from the latent baseline offsets when a latent loading is supplied;
/// all zeros otherwise.
pub unloaded_hazard_exit: Array1<f64>,
/// Copy of the time build's entry design; when a timewiggle is built it is
/// passed through `append_zero_tail_columns`.
pub time_design_entry: crate::matrix::DesignMatrix,
/// Copy of the time build's exit design; when a timewiggle is built it is
/// passed through `append_zero_tail_columns`.
pub time_design_exit: crate::matrix::DesignMatrix,
/// Copy of the time build's derivative design; when a timewiggle is built it
/// is passed through `append_zero_tail_columns`.
pub time_design_derivative_exit: crate::matrix::DesignMatrix,
/// The time build's penalties followed by one embedded penalty per timewiggle
/// penalty.
pub time_penalties: Vec<Array2<f64>>,
/// One nullspace dimension per entry pushed alongside `time_penalties`.
pub time_nullspace_dims: Vec<usize>,
/// The timewiggle build, present only when a timewiggle spec was supplied.
pub timewiggle_build: Option<crate::families::survival_construction::SurvivalTimeWiggleBuild>,
/// Knots, degree and column count copied from the timewiggle build, when present.
pub timewiggle_block: Option<TimeWiggleBlockInput>,
}
/// Assembles the time block of a survival fit: baseline offsets, the
/// derivative-guard adjustment, an optional timewiggle extension, and the
/// matching design matrices and penalties.
///
/// Steps, in order:
///
/// 1. Baseline offsets come from `build_latent_survival_baseline_offsets`
///    when `latent_loading` is given, otherwise from
///    `build_survival_time_offsets_for_likelihood` (with the three
///    `unloaded_*` arrays set to zeros).
/// 2. `add_survival_time_derivative_guard_offset` is applied to the three
///    offset arrays in place.
/// 3. When `effective_timewiggle` is given, a timewiggle is built from the
///    adjusted offsets.
/// 4. Design matrices, penalties and nullspace dimensions are cloned from
///    `time_build` and, when a timewiggle exists, extended for its columns.
///
/// # Errors
///
/// Propagates, as a `String`, any error from the offset builders, the
/// derivative-guard step or the timewiggle builder.
pub fn prepare_survival_time_stack(
age_entry: &Array1<f64>,
age_exit: &Array1<f64>,
baseline_cfg: &crate::families::survival_construction::SurvivalBaselineConfig,
likelihood_mode: SurvivalLikelihoodMode,
inverse_link: Option<&InverseLink>,
time_anchor: f64,
derivative_guard: f64,
time_build: &crate::families::survival_construction::SurvivalTimeBuildOutput,
effective_timewiggle: Option<&LinkWiggleFormulaSpec>,
latent_loading: Option<crate::families::lognormal_kernel::HazardLoading>,
) -> Result<PreparedSurvivalTimeStack, String> {
let (
mut eta_offset_entry,
mut eta_offset_exit,
mut derivative_offset_exit,
unloaded_mass_entry,
unloaded_mass_exit,
unloaded_hazard_exit,
) = if let Some(loading) = latent_loading {
// Latent path: loaded offsets and unloaded mass/hazard all come from the
// latent baseline builder.
let offsets =
build_latent_survival_baseline_offsets(age_entry, age_exit, baseline_cfg, loading)?;
(
offsets.loaded_eta_entry,
offsets.loaded_eta_exit,
offsets.loaded_derivative_exit,
offsets.unloaded_mass_entry,
offsets.unloaded_mass_exit,
offsets.unloaded_hazard_exit,
)
} else {
// Declared before `offset_cfg` so the substitute config outlives the
// reference taken to it below.
let conditioning_cfg;
// Marginal-slope mode with a linear baseline target computes its offsets
// against a substitute Weibull baseline (shape 1.0, scale seeded from the
// exit ages) instead of `baseline_cfg`.
let offset_cfg = if likelihood_mode == SurvivalLikelihoodMode::MarginalSlope
&& baseline_cfg.target == SurvivalBaselineTarget::Linear
{
let scale =
crate::families::survival_construction::positive_survival_time_seed(age_exit);
conditioning_cfg = crate::families::survival_construction::SurvivalBaselineConfig {
target: SurvivalBaselineTarget::Weibull,
scale: Some(scale),
shape: Some(1.0),
rate: None,
makeham: None,
};
&conditioning_cfg
} else {
baseline_cfg
};
let (eta_offset_entry, eta_offset_exit, derivative_offset_exit) =
build_survival_time_offsets_for_likelihood(
age_entry,
age_exit,
offset_cfg,
likelihood_mode,
inverse_link,
)?;
// Without a latent loading the unloaded mass/hazard arrays are all zeros.
let n = age_entry.len();
(
eta_offset_entry,
eta_offset_exit,
derivative_offset_exit,
Array1::zeros(n),
Array1::zeros(n),
Array1::zeros(n),
)
};
// Applied in place on both paths, before any timewiggle is built.
add_survival_time_derivative_guard_offset(
age_entry,
age_exit,
time_anchor,
derivative_guard,
&mut eta_offset_entry,
&mut eta_offset_exit,
&mut derivative_offset_exit,
)?;
// The timewiggle is built from the guard-adjusted offsets.
let timewiggle_build = if let Some(cfg) = effective_timewiggle {
Some(build_survival_timewiggle_from_baseline(
&eta_offset_entry,
&eta_offset_exit,
&derivative_offset_exit,
cfg,
)?)
} else {
None
};
let mut time_design_entry = time_build.x_entry_time.clone();
let mut time_design_exit = time_build.x_exit_time.clone();
let mut time_design_derivative_exit = time_build.x_derivative_time.clone();
let mut time_penalties = time_build.penalties.clone();
let mut time_nullspace_dims = time_build.nullspace_dims.clone();
let mut timewiggle_block = None;
if let Some(wiggle) = timewiggle_build.as_ref() {
// Column count of the base time design, read before the designs are extended.
let p_base = time_design_exit.ncols();
append_zero_tail_columns(
&mut time_design_entry,
&mut time_design_exit,
&mut time_design_derivative_exit,
wiggle.ncols,
);
for (idx, penalty) in wiggle.penalties.iter().enumerate() {
// Embed each wiggle penalty in the lower-right block of a zero matrix
// sized for base + wiggle columns.
let mut embedded = Array2::<f64>::zeros((p_base + wiggle.ncols, p_base + wiggle.ncols));
embedded
.slice_mut(s![
p_base..p_base + wiggle.ncols,
p_base..p_base + wiggle.ncols
])
.assign(penalty);
time_penalties.push(embedded);
// A missing nullspace dimension for this penalty falls back to 0.
time_nullspace_dims.push(wiggle.nullspace_dims.get(idx).copied().unwrap_or(0));
}
timewiggle_block = Some(TimeWiggleBlockInput {
knots: wiggle.knots.clone(),
degree: wiggle.degree,
ncols: wiggle.ncols,
});
}
Ok(PreparedSurvivalTimeStack {
eta_offset_entry,
eta_offset_exit,
derivative_offset_exit,
unloaded_mass_entry,
unloaded_mass_exit,
unloaded_hazard_exit,
time_design_entry,
time_design_exit,
time_design_derivative_exit,
time_penalties,
time_nullspace_dims,
timewiggle_build,
timewiggle_block,
})
}
/// Looks up the column named `column_name` for the given `role` and returns
/// an owned copy of its values.
///
/// # Errors
///
/// Propagates the lookup error from `resolve_role_col`, and returns
/// [`WorkflowError::SchemaMismatch`] naming the first row whose value is not
/// finite.
fn resolve_continuous_column(
    data: &Dataset,
    col_map: &HashMap<String, usize>,
    column_name: &str,
    role: &str,
) -> Result<Array1<f64>, WorkflowError> {
    let col_idx = resolve_role_col(col_map, column_name, role)?;
    let values = data.values.column(col_idx).to_owned();
    if let Some(row_idx) = values.iter().position(|entry| !entry.is_finite()) {
        let value = values[row_idx];
        return Err(WorkflowError::SchemaMismatch {
            reason: format!(
                "{role} column '{column_name}' contains non-finite value at row {row_idx}: {value}"
            ),
        });
    }
    Ok(values)
}
/// Resolves the offset column, defaulting to all zeros (one per data row)
/// when no column name is given.
///
/// # Errors
///
/// Propagates the error from `resolve_continuous_column` when the named
/// column cannot be resolved or contains non-finite values.
pub fn resolve_offset_column(
    data: &Dataset,
    col_map: &HashMap<String, usize>,
    column_name: Option<&str>,
) -> Result<Array1<f64>, WorkflowError> {
    match column_name {
        Some(name) => resolve_continuous_column(data, col_map, name, "offset"),
        None => Ok(Array1::zeros(data.values.nrows())),
    }
}
/// Resolves the weights column, defaulting to all ones (one per data row)
/// when no column name is given.
///
/// # Errors
///
/// Propagates the error from `resolve_continuous_column`, and returns
/// [`WorkflowError::SchemaMismatch`] naming the first row with a negative
/// weight.
pub fn resolve_weight_column(
    data: &Dataset,
    col_map: &HashMap<String, usize>,
    column_name: Option<&str>,
) -> Result<Array1<f64>, WorkflowError> {
    let column_name = match column_name {
        Some(name) => name,
        None => return Ok(Array1::ones(data.values.nrows())),
    };
    let values = resolve_continuous_column(data, col_map, column_name, "weights")?;
    if let Some(row_idx) = values.iter().position(|&weight| weight < 0.0) {
        let value = values[row_idx];
        return Err(WorkflowError::SchemaMismatch {
            reason: format!(
                "weights column '{column_name}' must be non-negative; found {value} at row {row_idx}"
            ),
        });
    }
    Ok(values)
}
/// Floor on the weighted standard deviation of the marginal-slope `z` column:
/// `validate_bernoulli_marginal_slope_z_column_variance` rejects a column
/// whose weighted SD is not strictly above this value, and reuses the same
/// number as the tolerance for counting near-unique values in its error text.
const MARGINAL_SLOPE_Z_WEIGHTED_SD_FLOOR: f64 = 1e-12;
/// Verifies that the `z` score column has non-degenerate weighted variance.
///
/// The weighted mean and (population-style) weighted variance are computed
/// with the supplied `weights`; the column passes when the resulting standard
/// deviation is finite and strictly above
/// `MARGINAL_SLOPE_Z_WEIGHTED_SD_FLOOR`.
///
/// # Errors
///
/// Returns [`WorkflowError::SchemaMismatch`] when `z` and `weights` differ in
/// length, and [`WorkflowError::InvalidConfig`] when the total weight is not
/// finite and positive or when the weighted standard deviation fails the
/// floor. The latter message summarises the near-unique values of `z`.
fn validate_bernoulli_marginal_slope_z_column_variance(
    z_column: &str,
    z: ArrayView1<'_, f64>,
    weights: ArrayView1<'_, f64>,
) -> Result<(), WorkflowError> {
    let n = z.len();
    if n != weights.len() {
        return Err(WorkflowError::SchemaMismatch {
            reason: format!(
                "z_column '{z_column}' length mismatch for bernoulli-marginal-slope: z={}, weights={}",
                z.len(),
                weights.len()
            ),
        });
    }

    let weight_sum: f64 = weights.iter().copied().sum();
    if !weight_sum.is_finite() || weight_sum <= 0.0 {
        return Err(WorkflowError::InvalidConfig {
            reason: format!(
                "z_column '{z_column}' cannot be weighted for bernoulli-marginal-slope because the fit data have non-positive or non-finite total weight"
            ),
        });
    }

    let weighted_total = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| wi * zi)
        .sum::<f64>();
    let mean = weighted_total / weight_sum;
    let weighted_squares = z
        .iter()
        .zip(weights.iter())
        .map(|(&zi, &wi)| wi * (zi - mean) * (zi - mean))
        .sum::<f64>();
    let weighted_sd = (weighted_squares / weight_sum).sqrt();
    if weighted_sd.is_finite() && weighted_sd > MARGINAL_SLOPE_Z_WEIGHTED_SD_FLOOR {
        return Ok(());
    }

    // Degenerate column: summarise its near-unique values for the error text.
    let mut sorted = z.to_vec();
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    sorted.dedup_by(|a, b| (*a - *b).abs() <= MARGINAL_SLOPE_Z_WEIGHTED_SD_FLOOR);
    let unique_count = sorted.len();
    let value_summary = match sorted.as_slice() {
        [] => "no observed finite values".to_string(),
        [only] => format!("all {n} values ~= {only:.6}"),
        [first, second] => {
            format!("{unique_count} near-unique values, e.g. {first:.6}, {second:.6}")
        }
        [first, second, ..] => {
            format!("{unique_count} near-unique values, e.g. {first:.6}, {second:.6}, ...")
        }
    };
    Err(WorkflowError::InvalidConfig {
        reason: format!(
            "z_column '{z_column}' has zero weighted variance on the fit data ({value_summary}; weighted_sd={weighted_sd:.6e}, n={n}); bernoulli-marginal-slope cannot identify a covariate-varying slope from a constant score. Check the score column and fit population."
        ),
    })
}
/// How a latent block's coordinates are initialised, as parsed from the
/// `init` entry of a latent spec.
#[derive(Clone)]
enum LatentInitSpec {
/// Chosen when `init` is absent or the string `"pca"` (case-insensitive).
Pca,
/// Chosen when `init` is the string `"random"` (case-insensitive).
Random,
/// Any other `init` value, parsed as a two-dimensional numeric array.
Explicit(Array2<f64>),
}
/// Auxiliary-prior settings parsed from the `aux_prior` object of a latent spec.
#[derive(Clone)]
struct LatentAuxPriorSpec {
/// Two-dimensional numeric array read from the required `aux_prior.u` entry.
u: Array2<f64>,
/// Prior family; the parser accepts `"ridge"` (the default) or `"linear"`.
family: AuxPriorFamily,
/// Prior strength. NOTE(review): its parsing is outside the visible part of
/// the file — confirm accepted values against the parser.
strength: AuxPriorStrength,
}
/// Dimension-selection settings attached to a latent spec.
#[derive(Clone)]
struct LatentDimSelectionSpec {
/// Optional initial log-precision vector. NOTE(review): presumably one entry
/// per latent dimension — confirm against the parser and its consumers.
init_log_precision: Option<Array1<f64>>,
}
/// Auxiliary-outcome settings attached to a latent spec.
#[derive(Clone)]
struct LatentAuxOutcomeSpec {
/// Outcome family of the auxiliary head.
family: crate::terms::behavioral_head::AuxOutcomeFamily,
/// Auxiliary outcome values. NOTE(review): presumably one per latent row —
/// confirm against the parser.
y: Array1<f64>,
/// Optional per-row weights for the auxiliary outcome.
row_weights: Option<Array1<f64>>,
/// Optional initial log-precision vector.
init_log_precision: Option<Array1<f64>>,
}
/// Manifold choice for a latent block, as produced by `parse_latent_manifold`.
#[derive(Clone)]
struct LatentManifoldSpec {
/// The parsed manifold; Euclidean whenever `auto` is true.
manifold: LatentManifold,
/// True when the manifold was omitted, null, or given as `"auto"`; false when
/// a concrete manifold was specified.
auto: bool,
}
/// One parsed entry of the `latents` JSON payload.
#[derive(Clone)]
struct LatentSpec {
/// The entry's `name` string, or the payload key when `name` is absent.
target: String,
/// Required positive `n` entry. NOTE(review): presumably the number of
/// latent rows — confirm against consumers.
n: usize,
/// Required positive `d` entry, used as the latent dimension when parsing
/// the manifold and retraction.
d: usize,
/// Initialisation strategy parsed from `init`.
init: LatentInitSpec,
/// Manifold parsed from `manifold`.
manifold: LatentManifoldSpec,
/// Retraction registry parsed from `retraction`.
retraction_registry: LatentRetractionRegistry,
/// Optional auxiliary prior parsed from `aux_prior`.
aux_prior: Option<LatentAuxPriorSpec>,
/// Optional dimension-selection settings.
dim_selection: Option<LatentDimSelectionSpec>,
/// Optional auxiliary-outcome settings.
aux_outcome: Option<LatentAuxOutcomeSpec>,
/// NOTE(review): set outside the visible part of the parser — confirm its
/// meaning there.
explicit_none_mode: bool,
}
/// Parses a JSON array of equal-length numeric arrays into an `Array2<f64>`.
///
/// The width is taken from the first row; an empty outer array is rejected
/// because it has no first row. `context` prefixes every error message.
///
/// # Errors
///
/// Returns a `String` when `value` is not an array, a row is not an array or
/// has the wrong length, or a cell is not a finite number.
fn json_array2(value: &JsonValue, context: &str) -> Result<Array2<f64>, String> {
    let rows = value
        .as_array()
        .ok_or_else(|| format!("{context} must be a two-dimensional numeric array"))?;
    let d = match rows.first().and_then(|row| row.as_array()) {
        Some(first) => first.len(),
        None => return Err(format!("{context} must contain array rows")),
    };
    let mut out = Array2::<f64>::zeros((rows.len(), d));
    for (i, row_value) in rows.iter().enumerate() {
        let Some(row) = row_value.as_array() else {
            return Err(format!("{context} row {i} must be an array"));
        };
        if row.len() != d {
            return Err(format!(
                "{context} row {i} has length {}, expected {d}",
                row.len()
            ));
        }
        for (j, cell) in row.iter().enumerate() {
            out[[i, j]] = match cell.as_f64() {
                Some(number) if number.is_finite() => number,
                Some(_) => return Err(format!("{context}[{i}][{j}] must be finite")),
                None => return Err(format!("{context}[{i}][{j}] must be a finite number")),
            };
        }
    }
    Ok(out)
}
/// Parses a JSON array of numbers into an `Array1<f64>`.
///
/// `context` prefixes every error message. An empty array yields an empty
/// vector.
///
/// # Errors
///
/// Returns a `String` when `value` is not an array or when a cell is not a
/// finite number; the first offending index is reported.
fn json_array1(value: &JsonValue, context: &str) -> Result<Array1<f64>, String> {
    let cells = value
        .as_array()
        .ok_or_else(|| format!("{context} must be a numeric array"))?;
    let parsed = cells
        .iter()
        .enumerate()
        .map(|(idx, cell)| match cell.as_f64() {
            Some(number) if number.is_finite() => Ok(number),
            Some(_) => Err(format!("{context}[{idx}] must be finite")),
            None => Err(format!("{context}[{idx}] must be a finite number")),
        })
        .collect::<Result<Vec<f64>, String>>()?;
    Ok(Array1::from_vec(parsed))
}
/// Parses the `manifold` entry of a latent spec for a latent of dimension `d`.
///
/// Accepted forms:
///
/// * absent, `null`, the string `"auto"`, or an object whose type is `"auto"`:
///   Euclidean with `auto = true`;
/// * a manifold name (`euclidean`, `circle`, `sphere`, `torus`, `cylinder`
///   and their aliases), matched case-insensitively;
/// * an object with a `type` (or `kind`) key, defaulting to `euclidean`; the
///   `interval` type additionally needs finite `lo`/`min` < `hi`/`max`;
/// * an array, parsed element-wise as one-dimensional manifolds and combined
///   into a product.
///
/// Any explicitly specified manifold must report an ambient dimension equal
/// to `d`. `context` prefixes every error message.
///
/// # Errors
///
/// Returns a `String` for unknown names, malformed intervals, unsupported
/// JSON shapes, a cylinder with `d < 2`, or an ambient-dimension mismatch.
fn parse_latent_manifold(
    value: Option<&JsonValue>,
    d: usize,
    context: &str,
) -> Result<LatentManifoldSpec, String> {
    let automatic = || LatentManifoldSpec {
        manifold: LatentManifold::Euclidean,
        auto: true,
    };
    // Every circle factor uses a period of one full turn in radians.
    let unit_circle = || LatentManifold::Circle {
        period: std::f64::consts::TAU,
    };
    let circle_product =
        |count: usize| LatentManifold::Product((0..count).map(|_| unit_circle()).collect());

    let value = match value {
        Some(raw) if !raw.is_null() => raw,
        _ => return Ok(automatic()),
    };
    if value
        .as_str()
        .is_some_and(|s| s.eq_ignore_ascii_case("auto"))
    {
        return Ok(automatic());
    }

    let parse_named = |name: &str| -> Result<LatentManifold, String> {
        match name.to_ascii_lowercase().as_str() {
            "euclidean" | "r" | "real" => Ok(LatentManifold::Euclidean),
            // A single circle stays bare; higher dimensions become a product.
            "circle" | "s1" | "periodic" if d == 1 => Ok(unit_circle()),
            "circle" | "s1" | "periodic" | "torus" => Ok(circle_product(d)),
            "sphere" | "sn" => Ok(LatentManifold::Sphere { dim: d }),
            "cylinder" if d < 2 => Err(format!("{context}='cylinder' requires d >= 2")),
            "cylinder" => {
                // One circle factor followed by d - 1 Euclidean factors.
                let parts = std::iter::once(unit_circle())
                    .chain((1..d).map(|_| LatentManifold::Euclidean))
                    .collect();
                Ok(LatentManifold::Product(parts))
            }
            other => Err(format!(
                "{context} must be 'auto', 'euclidean', 'circle', 'sphere', 'torus', or 'cylinder'; got '{other}'"
            )),
        }
    };

    let manifold = if let Some(name) = value.as_str() {
        parse_named(name)?
    } else if let Some(obj) = value.as_object() {
        let kind = obj
            .get("type")
            .or_else(|| obj.get("kind"))
            .and_then(JsonValue::as_str)
            .unwrap_or("euclidean")
            .to_ascii_lowercase();
        match kind.as_str() {
            "auto" => return Ok(automatic()),
            "interval" => {
                let bound = |primary: &str, alias: &str| {
                    obj.get(primary)
                        .or_else(|| obj.get(alias))
                        .and_then(JsonValue::as_f64)
                };
                let lo = bound("lo", "min")
                    .ok_or_else(|| format!("{context}.lo is required for interval"))?;
                let hi = bound("hi", "max")
                    .ok_or_else(|| format!("{context}.hi is required for interval"))?;
                if !(lo.is_finite() && hi.is_finite() && lo < hi) {
                    return Err(format!("{context} interval requires finite lo < hi"));
                }
                LatentManifold::Interval { lo, hi }
            }
            other => parse_named(other)?,
        }
    } else if let Some(items) = value.as_array() {
        // Each array element is parsed as a one-dimensional factor.
        let parts = items
            .iter()
            .enumerate()
            .map(|(idx, item)| {
                parse_latent_manifold(Some(item), 1, &format!("{context}[{idx}]"))
                    .map(|spec| spec.manifold)
            })
            .collect::<Result<Vec<_>, String>>()?;
        LatentManifold::Product(parts)
    } else {
        return Err(format!(
            "{context} must be a string, object, or product array"
        ));
    };

    let ambient = manifold.ambient_dim(d);
    if ambient != d {
        return Err(format!(
            "{context} ambient dimension {} does not match latent d={d}",
            ambient
        ));
    }
    Ok(LatentManifoldSpec {
        manifold,
        auto: false,
    })
}
/// Parses a retraction description from JSON.
///
/// Accepted forms:
///
/// * a name (`euclidean`, `circle`, `sphere` and their aliases), matched
///   case-insensitively and sized by `fallback_dim`;
/// * an array, parsed element-wise with a fallback dimension of 1 and
///   combined into a product;
/// * an object with a `type` (or `kind`) key, defaulting to `euclidean`.
///   Euclidean and sphere objects read an optional `dim` (or `d`) entry and
///   fall back to `fallback_dim` when it is absent or not an unsigned
///   integer; `product` objects need a `parts` (or `components`) array.
///
/// `context` prefixes every error message.
///
/// # Errors
///
/// Returns a `String` for unknown names, unsupported JSON shapes, a zero
/// `dim`, or a product object without parts.
fn parse_retraction_kind(
    value: &JsonValue,
    fallback_dim: usize,
    context: &str,
) -> Result<RetractionKind, String> {
    let parse_named = |name: &str| -> Result<RetractionKind, String> {
        let kind = match name.to_ascii_lowercase().as_str() {
            "euclidean" | "r" | "real" => RetractionKind::euclidean(fallback_dim),
            // A single circle stays bare; higher dimensions become a product.
            "circle" | "s1" | "periodic" if fallback_dim == 1 => RetractionKind::Circle,
            "circle" | "s1" | "periodic" => RetractionKind::Product(ProductRetraction {
                parts: (0..fallback_dim).map(|_| RetractionKind::Circle).collect(),
            }),
            "sphere" | "sn" => RetractionKind::Sphere { dim: fallback_dim },
            other => {
                return Err(format!(
                    "{context} must be 'euclidean', 'circle', 'sphere', or a product; got '{other}'"
                ));
            }
        };
        Ok(kind)
    };

    if let Some(name) = value.as_str() {
        return parse_named(name);
    }
    if let Some(items) = value.as_array() {
        let parts = items
            .iter()
            .enumerate()
            .map(|(idx, item)| parse_retraction_kind(item, 1, &format!("{context}[{idx}]")))
            .collect::<Result<Vec<_>, String>>()?;
        return Ok(RetractionKind::Product(ProductRetraction { parts }));
    }
    let Some(obj) = value.as_object() else {
        return Err(format!(
            "{context} must be a string, object, or product array"
        ));
    };

    // Reads `dim` (or `d`), falling back to `fallback_dim`; zero is rejected.
    let read_dim = || -> Result<usize, String> {
        let dim = obj
            .get("dim")
            .or_else(|| obj.get("d"))
            .and_then(JsonValue::as_u64)
            .map_or(fallback_dim, |value| value as usize);
        if dim == 0 {
            return Err(format!("{context}.dim must be positive"));
        }
        Ok(dim)
    };

    let kind = obj
        .get("type")
        .or_else(|| obj.get("kind"))
        .and_then(JsonValue::as_str)
        .unwrap_or("euclidean")
        .to_ascii_lowercase();
    match kind.as_str() {
        "euclidean" | "r" | "real" => Ok(RetractionKind::euclidean(read_dim()?)),
        "circle" | "s1" | "periodic" => Ok(RetractionKind::Circle),
        "sphere" | "sn" => Ok(RetractionKind::Sphere { dim: read_dim()? }),
        "product" => {
            let items = obj
                .get("parts")
                .or_else(|| obj.get("components"))
                .and_then(JsonValue::as_array)
                .ok_or_else(|| format!("{context}.parts is required for product retraction"))?;
            let parts = items
                .iter()
                .enumerate()
                .map(|(idx, item)| {
                    parse_retraction_kind(item, 1, &format!("{context}.parts[{idx}]"))
                })
                .collect::<Result<Vec<_>, String>>()?;
            Ok(RetractionKind::Product(ProductRetraction { parts }))
        }
        // Every supported name is handled above, so this yields the
        // standard "must be ..." error for the unknown type.
        other => parse_named(other),
    }
}
/// Parses the `retraction` entry of a latent spec into a registry.
///
/// An absent or `null` entry yields the all-Euclidean registry. Otherwise the
/// entry is parsed with `parse_retraction_kind` and the resulting registry is
/// validated against the latent dimension `d`.
///
/// # Errors
///
/// Propagates parse errors and the registry's dimension-validation error.
fn parse_latent_retraction(
    value: Option<&JsonValue>,
    d: usize,
    context: &str,
) -> Result<LatentRetractionRegistry, String> {
    match value.filter(|raw| !raw.is_null()) {
        None => Ok(LatentRetractionRegistry::all_euclidean()),
        Some(raw) => {
            let registry = LatentRetractionRegistry::new(parse_retraction_kind(raw, d, context)?);
            registry.validate_dim(d, context)?;
            Ok(registry)
        }
    }
}
/// Parses the optional `latents` JSON payload into a list of [`LatentSpec`]s.
///
/// A missing or JSON-null payload yields an empty list. Otherwise the payload
/// must be a JSON object; each `(key, value)` entry describes one latent block
/// and `value` must itself be an object. Per entry the following fields are read:
///
/// * `name` — target symbol; defaults to the map key.
/// * `n`, `d` — required unsigned integers, both must be non-zero.
/// * `manifold`, `retraction` — delegated to `parse_latent_manifold` /
///   `parse_latent_retraction` with `d`.
/// * `init` — absent => `LatentInitSpec::Pca`; the strings `"pca"` / `"random"`
///   (ASCII case-insensitive) select the matching variant; any other value is
///   parsed as an explicit matrix via `json_array2`.
/// * `aux_prior` — optional object with required `u`, optional `family`
///   (`"ridge"` default, or `"linear"`) and optional `strength`
///   (absent => `Fixed(1.0)`, `"auto"`, or a finite positive number).
/// * `dim_selection` — absent/`false` => none, `true` => enabled without an
///   initial value, or an object with optional `init_log_precision`.
/// * `aux_outcome` — optional object with optional `family` (`"binomial"`
///   default, or `"multinomial"` which requires `n_classes`), required `y` of
///   length `n`, optional `row_weights` of length `n`, and optional
///   `init_log_precision`.
/// * `id_mode` (or `mode`) — the string `"none"` marks an explicit opt-out of
///   the identifiability requirement.
///
/// # Errors
/// Returns a message naming the offending `latents['<key>']...` path when a
/// field is missing, has the wrong JSON type, has the wrong length, or when the
/// cross-field identifiability rules at the end of the loop body are violated.
fn parse_latent_specs(payload: Option<&JsonValue>) -> Result<Vec<LatentSpec>, String> {
// Absent and explicit-null payloads both mean "no latent blocks".
let Some(payload) = payload.filter(|value| !value.is_null()) else {
return Ok(Vec::new());
};
let map = payload
.as_object()
.ok_or_else(|| "latents must be a JSON object keyed by formula symbol".to_string())?;
let mut specs = Vec::with_capacity(map.len());
for (key, raw) in map {
let obj = raw
.as_object()
.ok_or_else(|| format!("latents['{key}'] must be an object"))?;
// The target symbol defaults to the map key when `name` is not given.
let target = obj
.get("name")
.and_then(JsonValue::as_str)
.unwrap_or(key)
.to_string();
// NOTE(review): `as usize` narrows the parsed u64 without a range check;
// confirm 32-bit targets are out of scope.
let n = obj
.get("n")
.and_then(JsonValue::as_u64)
.ok_or_else(|| format!("latents['{key}'].n is required"))? as usize;
let d = obj
.get("d")
.and_then(JsonValue::as_u64)
.ok_or_else(|| format!("latents['{key}'].d is required"))? as usize;
if n == 0 || d == 0 {
return Err(format!("latents['{key}'] requires positive n and d"));
}
let manifold = parse_latent_manifold(
obj.get("manifold"),
d,
&format!("latents['{key}'].manifold"),
)?;
let retraction_registry = parse_latent_retraction(
obj.get("retraction"),
d,
&format!("latents['{key}'].retraction"),
)?;
// Initialisation strategy. NOTE(review): unlike `aux_prior` / `aux_outcome`
// below, a JSON null `init` is not filtered out here and reaches the
// explicit-matrix arm — confirm that is intended.
let init = match obj.get("init") {
None => LatentInitSpec::Pca,
Some(value)
if value
.as_str()
.is_some_and(|s| s.eq_ignore_ascii_case("pca")) =>
{
LatentInitSpec::Pca
}
Some(value)
if value
.as_str()
.is_some_and(|s| s.eq_ignore_ascii_case("random")) =>
{
LatentInitSpec::Random
}
Some(value) => {
LatentInitSpec::Explicit(json_array2(value, &format!("latents['{key}'].init"))?)
}
};
// Optional auxiliary prior: a covariate matrix `u`, a family and a strength.
let aux_prior = match obj.get("aux_prior").filter(|value| !value.is_null()) {
None => None,
Some(value) => {
let aux = value
.as_object()
.ok_or_else(|| format!("latents['{key}'].aux_prior must be an object"))?;
let u = json_array2(
aux.get("u")
.ok_or_else(|| format!("latents['{key}'].aux_prior.u is required"))?,
&format!("latents['{key}'].aux_prior.u"),
)?;
// A non-string `family` falls back to "ridge" because `as_str` yields None.
let family = match aux
.get("family")
.and_then(JsonValue::as_str)
.unwrap_or("ridge")
.to_ascii_lowercase()
.as_str()
{
"ridge" => AuxPriorFamily::Ridge,
"linear" => AuxPriorFamily::Linear,
other => {
return Err(format!(
"latents['{key}'].aux_prior.family must be 'ridge' or 'linear', got '{other}'"
));
}
};
// Strength: default 1.0, the literal "auto", or a finite positive number.
let strength = match aux.get("strength") {
None => AuxPriorStrength::Fixed(1.0),
Some(value)
if value
.as_str()
.is_some_and(|s| s.eq_ignore_ascii_case("auto")) =>
{
AuxPriorStrength::Auto
}
Some(value) => {
let mu = value.as_f64().ok_or_else(|| {
format!(
"latents['{key}'].aux_prior.strength must be positive or 'auto'"
)
})?;
if !mu.is_finite() || mu <= 0.0 {
return Err(format!(
"latents['{key}'].aux_prior.strength must be positive"
));
}
AuxPriorStrength::Fixed(mu)
}
};
Some(LatentAuxPriorSpec {
u,
family,
strength,
})
}
};
// Dimension selection: a bool toggle or an object carrying an optional
// initial log-precision vector.
let dim_selection = match obj.get("dim_selection") {
None | Some(JsonValue::Bool(false)) => None,
Some(JsonValue::Bool(true)) => Some(LatentDimSelectionSpec {
init_log_precision: None,
}),
Some(value) => {
let dim = value.as_object().ok_or_else(|| {
format!("latents['{key}'].dim_selection must be a bool or object")
})?;
let init_log_precision = dim
.get("init_log_precision")
.map(|value| {
json_array1(
value,
&format!("latents['{key}'].dim_selection.init_log_precision"),
)
})
.transpose()?;
Some(LatentDimSelectionSpec { init_log_precision })
}
};
// Optional modeled auxiliary outcome; `y` and `row_weights` must match `n`.
let aux_outcome = match obj.get("aux_outcome").filter(|value| !value.is_null()) {
None => None,
Some(value) => {
use crate::terms::behavioral_head::AuxOutcomeFamily;
let ao = value
.as_object()
.ok_or_else(|| format!("latents['{key}'].aux_outcome must be an object"))?;
let family = match ao
.get("family")
.and_then(JsonValue::as_str)
.unwrap_or("binomial")
.to_ascii_lowercase()
.as_str()
{
"binomial" => AuxOutcomeFamily::Binomial,
"multinomial" => {
// NOTE(review): no lower bound on `n_classes` is enforced here;
// presumably validated downstream — verify.
let n_classes = ao
.get("n_classes")
.and_then(JsonValue::as_u64)
.ok_or_else(|| {
format!(
"latents['{key}'].aux_outcome.n_classes is required for multinomial"
)
})? as usize;
AuxOutcomeFamily::Multinomial { n_classes }
}
other => {
return Err(format!(
"latents['{key}'].aux_outcome.family must be 'binomial' or 'multinomial', got '{other}'"
));
}
};
let y = json_array1(
ao.get("y")
.ok_or_else(|| format!("latents['{key}'].aux_outcome.y is required"))?,
&format!("latents['{key}'].aux_outcome.y"),
)?;
if y.len() != n {
return Err(format!(
"latents['{key}'].aux_outcome.y has length {}, expected n = {n}",
y.len()
));
}
let row_weights = ao
.get("row_weights")
.filter(|value| !value.is_null())
.map(|value| {
json_array1(value, &format!("latents['{key}'].aux_outcome.row_weights"))
})
.transpose()?;
if let Some(w) = row_weights.as_ref()
&& w.len() != n
{
return Err(format!(
"latents['{key}'].aux_outcome.row_weights has length {}, expected n = {n}",
w.len()
));
}
let init_log_precision = ao
.get("init_log_precision")
.map(|value| {
json_array1(
value,
&format!("latents['{key}'].aux_outcome.init_log_precision"),
)
})
.transpose()?;
Some(LatentAuxOutcomeSpec {
family,
y,
row_weights,
init_log_precision,
})
}
};
// Cross-field identifiability rules, checked in this order:
// 1. dim_selection needs an aux_prior or aux_outcome alongside it.
if dim_selection.is_some() && aux_prior.is_none() && aux_outcome.is_none() {
return Err(format!(
"latents['{key}'] uses dim_selection without aux_prior or aux_outcome; ARD alone is not an identifiable latent-coordinate gauge"
));
}
// 2. aux_prior and aux_outcome are mutually exclusive.
if aux_outcome.is_some() && aux_prior.is_some() {
return Err(format!(
"latents['{key}'] specifies both aux_prior and aux_outcome; the auxiliary signal is either a prior (gauge-pin covariate) or a modeled outcome (behavioral head), not both"
));
}
// 3. With no auxiliary signal at all, the caller must opt out explicitly
//    via id_mode (or mode) = "none".
let explicit_none_mode = obj
.get("id_mode")
.or_else(|| obj.get("mode"))
.and_then(JsonValue::as_str)
.is_some_and(|s| s.eq_ignore_ascii_case("none"));
if aux_prior.is_none()
&& dim_selection.is_none()
&& aux_outcome.is_none()
&& !explicit_none_mode
{
return Err(format!(
"latents['{key}'] requires aux_prior or aux_outcome for identifiable joint REML; pass id_mode='none' only when a separate gauge fix is supplied"
));
}
specs.push(LatentSpec {
target,
n,
d,
init,
manifold,
retraction_registry,
aux_prior,
dim_selection,
aux_outcome,
explicit_none_mode,
});
}
Ok(specs)
}
/// Advances `seed` by one step of a 64-bit linear congruential generator and
/// maps the new state to a value in `[0, 1)`.
///
/// The state update uses wrapping arithmetic, so it never overflows. The top
/// 53 bits of the new state are scaled by `2^-53`, which keeps the result
/// strictly below `1.0`.
fn deterministic_unit(seed: &mut u64) -> f64 {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;
    // 2^-53: one unit in the last place of a 53-bit mantissa.
    const SCALE: f64 = 1.0 / ((1_u64 << 53) as f64);

    let next = seed.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *seed = next;
    ((next >> 11) as f64) * SCALE
}
/// Builds the starting `spec.n x spec.d` latent-coordinate matrix for `spec`.
///
/// The strategy is selected by `spec.init`:
///
/// * `Explicit(matrix)` — returns a clone of the supplied matrix after checking
///   that its shape is exactly `n x d`.
/// * `Random` — fills the matrix with values from `deterministic_unit`, seeded
///   from a fixed constant mixed with `n` and `d`, so the result is
///   reproducible for a given shape.
/// * `Pca` — despite the variant name, no decomposition is performed here:
///   axis 0 receives the response `y` standardised by its mean and population
///   standard deviation (floored at `1e-12`), and every further axis receives
///   deterministic jitter in `[-0.5, 0.5)`.
///
/// # Errors
/// * `Explicit` with a shape other than `n x d`.
/// * `Pca` when `y` has fewer than `n` entries. This used to be an
///   out-of-bounds index panic; it is now reported as an error.
///
/// With `d == 0` the `Pca` branch returns the empty-column matrix, as the other
/// branches do, instead of indexing a non-existent first axis.
fn initial_latent_matrix(spec: &LatentSpec, y: ArrayView1<'_, f64>) -> Result<Array2<f64>, String> {
    match &spec.init {
        LatentInitSpec::Explicit(matrix) => {
            if matrix.nrows() != spec.n || matrix.ncols() != spec.d {
                return Err(format!(
                    "latent '{}' explicit init has shape {}x{}, expected {}x{}",
                    spec.target,
                    matrix.nrows(),
                    matrix.ncols(),
                    spec.n,
                    spec.d
                ));
            }
            Ok(matrix.clone())
        }
        LatentInitSpec::Random => {
            let mut seed = 0x9E3779B97F4A7C15_u64 ^ ((spec.n as u64) << 32) ^ spec.d as u64;
            let mut out = Array2::<f64>::zeros((spec.n, spec.d));
            for value in out.iter_mut() {
                *value = deterministic_unit(&mut seed);
            }
            Ok(out)
        }
        LatentInitSpec::Pca => {
            // Guard the row count up front: the fill below reads one response
            // value per latent row.
            if y.len() < spec.n {
                return Err(format!(
                    "latent '{}' response-based init needs {} response rows, got {}",
                    spec.target,
                    spec.n,
                    y.len()
                ));
            }
            let mut out = Array2::<f64>::zeros((spec.n, spec.d));
            if spec.d == 0 {
                // No axis to seed; match the empty-column result of the other branches.
                return Ok(out);
            }
            // `max(1)` keeps the divisions finite for an empty response.
            let count = y.len().max(1) as f64;
            let mean = y.iter().sum::<f64>() / count;
            let var = y
                .iter()
                .map(|v| {
                    let centered = *v - mean;
                    centered * centered
                })
                .sum::<f64>()
                / count;
            // Floor the scale so a constant response does not divide by zero.
            let sd = var.sqrt().max(1e-12);
            {
                let mut first_axis = out.column_mut(0);
                // The column has exactly `n` slots, so zip stops after `n` rows.
                for (slot, value) in first_axis.iter_mut().zip(y.iter()) {
                    *slot = (*value - mean) / sd;
                }
            }
            if spec.d > 1 {
                let mut seed = 0xD1B54A32D192ED03_u64 ^ ((spec.n as u64) << 16) ^ spec.d as u64;
                // Row-major fill order is part of the deterministic contract.
                for n in 0..spec.n {
                    for axis in 1..spec.d {
                        out[[n, axis]] = deterministic_unit(&mut seed) - 0.5;
                    }
                }
            }
            Ok(out)
        }
    }
}
/// Translates the identification-related fields of `spec` into a
/// [`LatentIdMode`].
///
/// Resolution order:
/// 1. `aux_outcome` present — builds a `BehavioralHead` (weighted when
///    `row_weights` is given, `fully_supervised` otherwise) and returns
///    `AuxOutcome`; `aux_prior` and `dim_selection` are not consulted.
/// 2. `aux_prior` with `dim_selection` — `AuxPriorDimSelection`.
/// 3. `aux_prior` alone — `AuxPrior`.
/// 4. Neither, with `explicit_none_mode` set — `LatentIdMode::None`.
///
/// # Errors
/// * An `init_log_precision` vector whose length differs from `spec.d`.
/// * Failure to build the behavioral head.
/// * `dim_selection` without `aux_prior`, or no identification source at all
///   without the explicit opt-out.
fn latent_id_mode(spec: &LatentSpec) -> Result<LatentIdMode, String> {
    // A modeled auxiliary outcome is handled first and returns on its own.
    if let Some(outcome) = spec.aux_outcome.as_ref() {
        use crate::terms::behavioral_head::BehavioralHead;
        let mismatched_len = outcome
            .init_log_precision
            .as_ref()
            .map(|init| init.len())
            .filter(|len| *len != spec.d);
        if let Some(len) = mismatched_len {
            return Err(format!(
                "latent '{}' aux_outcome.init_log_precision has length {}, expected {}",
                spec.target, len, spec.d
            ));
        }
        let built = if let Some(weights) = outcome.row_weights.as_ref() {
            BehavioralHead::new(outcome.family, outcome.y.clone(), weights.clone())
        } else {
            BehavioralHead::fully_supervised(outcome.family, outcome.y.clone())
        };
        let head =
            built.map_err(|e| format!("latent '{}' aux_outcome head: {e}", spec.target))?;
        return Ok(LatentIdMode::AuxOutcome {
            head,
            init_log_precision: outcome.init_log_precision.clone(),
        });
    }

    // Without an auxiliary prior only the explicit opt-out is acceptable.
    let Some(aux) = spec.aux_prior.as_ref() else {
        return match spec.dim_selection.as_ref() {
            Some(_) => Err(format!(
                "latent '{}' dim_selection requires aux_prior for identifiability",
                spec.target
            )),
            None if spec.explicit_none_mode => Ok(LatentIdMode::None),
            None => Err(format!(
                "latent '{}' requires aux_prior for identifiable joint REML; pass id_mode='none' only when a separate gauge fix is supplied",
                spec.target
            )),
        };
    };

    // Auxiliary prior without dimension selection.
    let Some(dim) = spec.dim_selection.as_ref() else {
        return Ok(LatentIdMode::AuxPrior {
            u: aux.u.clone(),
            family: aux.family,
            strength: aux.strength,
        });
    };

    // Auxiliary prior combined with dimension selection.
    let mismatched_len = dim
        .init_log_precision
        .as_ref()
        .map(|init| init.len())
        .filter(|len| *len != spec.d);
    if let Some(len) = mismatched_len {
        return Err(format!(
            "latent '{}' dim_selection.init_log_precision has length {}, expected {}",
            spec.target, len, spec.d
        ));
    }
    Ok(LatentIdMode::AuxPriorDimSelection {
        u: aux.u.clone(),
        family: aux.family,
        strength: aux.strength,
        init_log_precision: dim.init_log_precision.clone(),
    })
}
/// Prepares the latent-coordinate extension of a standard fit, if one is
/// requested through `config.latents`.
///
/// Returns `Ok(None)` when no latent block is configured. Otherwise returns:
/// * a copy of `data` with `d` extra continuous columns named
///   `"<target>__latent<axis>"` holding the initial latent coordinates,
/// * a copy of `parsed` in which every single-variable smooth on `<target>` is
///   rewritten to use those synthetic columns, and
/// * the `StandardLatentCoordConfig` describing the block. Its `term_index` is
///   a placeholder here.
///
/// # Errors
/// * Parsing failures from `parse_latent_specs` or the analytic-penalty registry.
/// * `topology_auto_selector` set without any latent block, or naming a latent
///   other than the configured one.
/// * More than one latent block.
/// * `n` differing from the data row count or from `y.len()`, or `aux_prior.u`
///   having a row count other than `n`.
/// * Failures from `initial_latent_matrix` / `latent_id_mode`.
/// * No matching `s(<target>, ...)` term in the formula.
fn prepare_standard_latent_coord(
    parsed: &ParsedFormula,
    data: &Dataset,
    y: ArrayView1<'_, f64>,
    config: &FitConfig,
) -> Result<Option<(Dataset, ParsedFormula, StandardLatentCoordConfig)>, String> {
    let latent_payload = config.latents.as_ref();
    let specs = parse_latent_specs(latent_payload)?;
    // Built before the emptiness checks so descriptor errors surface first.
    let penalty_registry = descriptors::build_analytic_penalty_registry_from_descriptors(
        latent_payload,
        config.analytic_penalties.as_ref(),
    )?;

    if specs.is_empty() {
        return if config.topology_auto_selector.is_some() {
            Err(
                "TopologyAutoSelector requires a Smooth with latent coords; pass latents={...}"
                    .to_string(),
            )
        } else {
            Ok(None)
        };
    }

    // Exactly one latent block is supported: take the first, require no second.
    let mut remaining = specs.into_iter();
    let (Some(spec), None) = (remaining.next(), remaining.next()) else {
        return Err(
            "standard latent-coordinate REML currently accepts exactly one latent smooth term"
                .to_string(),
        );
    };

    let requested_latent = config
        .topology_auto_selector
        .as_ref()
        .and_then(|selector| selector.latent.as_ref());
    if let Some(requested) = requested_latent.filter(|name| **name != spec.target) {
        return Err(format!(
            "TopologyAutoSelector requested latent {requested:?}, but the formula path materialized latent {:?}",
            spec.target
        ));
    }

    let data_rows = data.values.nrows();
    if !(spec.n == data_rows && spec.n == y.len()) {
        return Err(format!(
            "latent '{}' row count {} does not match data rows {}",
            spec.target, spec.n, data_rows
        ));
    }
    let aux_rows = spec.aux_prior.as_ref().map(|aux| aux.u.nrows());
    if let Some(rows) = aux_rows.filter(|rows| *rows != spec.n) {
        return Err(format!(
            "latent '{}' aux_prior.u has {} rows, expected {}",
            spec.target, rows, spec.n
        ));
    }

    let matrix = initial_latent_matrix(&spec, y)?;
    let id_mode = latent_id_mode(&spec)?;
    let latent_values = Arc::new(LatentCoordValues::from_matrix_with_manifold_and_retraction(
        matrix.view(),
        id_mode,
        spec.manifold.manifold.clone(),
        spec.retraction_registry.clone(),
    ));

    // Append the latent axes as synthetic continuous columns after the originals.
    let base_cols = data.values.ncols();
    let synthetic_vars: Vec<String> = (0..spec.d)
        .map(|axis| format!("{}__latent{}", spec.target, axis))
        .collect();
    let feature_cols: Vec<usize> = (base_cols..base_cols + spec.d).collect();

    let mut values = Array2::<f64>::zeros((data_rows, base_cols + spec.d));
    values.slice_mut(s![.., ..base_cols]).assign(&data.values);
    values.slice_mut(s![.., base_cols..]).assign(&matrix);

    let mut headers = data.headers.clone();
    headers.extend(synthetic_vars.iter().cloned());
    let mut columns = data.schema.columns.clone();
    columns.extend(synthetic_vars.iter().map(|name| SchemaColumn {
        name: name.clone(),
        kind: ColumnKindTag::Continuous,
        levels: Vec::new(),
    }));
    let mut column_kinds = data.column_kinds.clone();
    column_kinds.extend((0..spec.d).map(|_| ColumnKindTag::Continuous));

    let augmented = Dataset {
        headers,
        values,
        schema: DataSchema { columns },
        column_kinds,
    };

    // Point every `s(<target>)` smooth at the synthetic columns.
    let mut rewritten = parsed.clone();
    let mut matched = false;
    for term in rewritten.terms.iter_mut() {
        let ParsedTerm::Smooth { vars, .. } = term else {
            continue;
        };
        if matches!(vars.as_slice(), [only] if *only == spec.target) {
            *vars = synthetic_vars.clone();
            matched = true;
        }
    }
    if !matched {
        return Err(format!(
            "latents provided '{}' but no formula smooth term s({}, ...) was found",
            spec.target, spec.target
        ));
    }

    // An empty registry is represented as `None`.
    let analytic_penalties = if penalty_registry.penalties.is_empty() {
        None
    } else {
        Some(Arc::new(penalty_registry))
    };
    let coord = StandardLatentCoordConfig {
        values: latent_values,
        term_index: crate::types::SmoothTermIdx::placeholder(),
        feature_cols,
        manifold: spec.manifold.manifold,
        manifold_auto: spec.manifold.auto,
        retraction_registry: spec.retraction_registry,
        analytic_penalties,
    };
    Ok(Some((augmented, rewritten, coord)))
}
fn smooth_basis_feature_cols_for_latent(
basis: &crate::smooth::SmoothBasisSpec,
) -> Option<Vec<usize>> {
match basis {
crate::smooth::SmoothBasisSpec::BSpline1D { feature_col, .. } => Some(vec![*feature_col]),
crate::smooth::SmoothBasisSpec::ThinPlate { feature_cols, .. }
| crate::smooth::SmoothBasisSpec::Sphere { feature_cols, .. }
| crate::smooth::SmoothBasisSpec::ConstantCurvature { feature_cols, .. }
| crate::smooth::SmoothBasisSpec::Matern { feature_cols, .. }
| crate::smooth::SmoothBasisSpec::MeasureJet { feature_cols, .. }
| crate::smooth::SmoothBasisSpec::Duchon { feature_cols, .. }
| crate::smooth::SmoothBasisSpec::Pca { feature_cols, .. }
| crate::smooth::SmoothBasisSpec::TensorBSpline { feature_cols, .. } => {
Some(feature_cols.clone())
}
crate::smooth::SmoothBasisSpec::BySmooth { smooth, .. } => {
smooth_basis_feature_cols_for_latent(smooth)
}
crate::smooth::SmoothBasisSpec::ByVariable { inner, .. }
| crate::smooth::SmoothBasisSpec::FactorSumToZero { inner, .. } => {
smooth_basis_feature_cols_for_latent(inner)
}
crate::smooth::SmoothBasisSpec::FactorSmooth { .. } => None,
}
}
fn natural_latent_manifold_for_basis(
basis: &crate::smooth::SmoothBasisSpec,
d: usize,
) -> LatentManifold {
match basis {
crate::smooth::SmoothBasisSpec::BSpline1D { spec, .. } => {
if let crate::basis::BSplineKnotSpec::PeriodicUniform { data_range, .. } =
&spec.knotspec
{
LatentManifold::Circle {
period: data_range.1 - data_range.0,
}
} else {
LatentManifold::Euclidean
}
}
crate::smooth::SmoothBasisSpec::Sphere { .. } => LatentManifold::Sphere { dim: d },
crate::smooth::SmoothBasisSpec::Duchon { spec, .. }
if spec.periodic.is_some() && d == 1 =>
{
let period = spec
.periodic
.as_ref()
.and_then(|v| v.first().copied().flatten())
.unwrap_or(std::f64::consts::TAU);
LatentManifold::Circle { period }
}
crate::smooth::SmoothBasisSpec::TensorBSpline { spec, .. } => {
let parts: Vec<LatentManifold> = spec
.marginalspecs
.iter()
.map(|margin| {
if let crate::basis::BSplineKnotSpec::PeriodicUniform { data_range, .. } =
&margin.knotspec
{
LatentManifold::Circle {
period: data_range.1 - data_range.0,
}
} else {
LatentManifold::Euclidean
}
})
.collect();
if parts.iter().all(|part| part.is_euclidean()) {
LatentManifold::Euclidean
} else {
LatentManifold::Product(parts)
}
}
crate::smooth::SmoothBasisSpec::BySmooth { smooth, .. } => {
natural_latent_manifold_for_basis(smooth, d)
}
crate::smooth::SmoothBasisSpec::ByVariable { inner, .. }
| crate::smooth::SmoothBasisSpec::FactorSumToZero { inner, .. } => {
natural_latent_manifold_for_basis(inner, d)
}
crate::smooth::SmoothBasisSpec::ThinPlate { .. }
| crate::smooth::SmoothBasisSpec::ConstantCurvature { .. }
| crate::smooth::SmoothBasisSpec::Matern { .. }
| crate::smooth::SmoothBasisSpec::MeasureJet { .. }
| crate::smooth::SmoothBasisSpec::Duchon { .. }
| crate::smooth::SmoothBasisSpec::Pca { .. }
| crate::smooth::SmoothBasisSpec::FactorSmooth { .. } => LatentManifold::Euclidean,
}
}
/// Materializes a standard (single-predictor) fit request from a parsed formula.
///
/// Steps, in order:
/// 1. Rejects `noise_offset_column`, which is reported as requiring a
///    location-scale model.
/// 2. Resolves the response column, the link choice and the family, then runs
///    the family's response support and degeneracy validation.
/// 3. Rejects an explicit `linkwiggle(...)` for a non-binomial family.
/// 4. Prepares the optional latent-coordinate block; when present, the term
///    spec is built from the augmented dataset and rewritten formula.
/// 5. Builds the term spec, checks smooth capacity, and resolves the latent
///    smooth's term index (and its manifold when `manifold_auto` is set).
/// 6. Resolves the weight and offset columns, the latent-cloglog state, the
///    fit options and the optional binomial link-wiggle configuration.
///
/// # Errors
/// Returns a `WorkflowError` for any failed validation above; helper errors
/// produced as `String` are converted through `?` / `.into()`.
pub(crate) fn materialize_standard<'a>(
parsed: &ParsedFormula,
data: &'a Dataset,
col_map: &HashMap<String, usize>,
config: &FitConfig,
) -> Result<MaterializedModel<'a>, WorkflowError> {
if config.noise_offset_column.is_some() {
return Err(
"noise_offset_column requires a location-scale model with noise_formula"
.to_string()
.into(),
);
}
let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
let y = data.values.column(y_col).to_owned();
let y_kind = response_column_kind(data, y_col);
let mut inference_notes = Vec::new();
let link_choice = parse_link_choice(config.link.as_deref(), config.flexible_link)?;
let family = resolve_family(
config.family.as_deref(),
config.negative_binomial_theta,
link_choice.as_ref(),
y.view(),
y_kind,
&parsed.response,
)?;
// Validate the response against the resolved family before building any terms.
family
.response
.validate_response_support(y.view())
.map_err(|violation| violation.message_for(&parsed.response))?;
family
.response
.validate_response_degeneracy(y.view())
.map_err(|deg| deg.message_for(&parsed.response))?;
reject_explicit_linkwiggle_for_nonbinomial(parsed, &family)?;
let effective_linkwiggle =
effectivelinkwiggle_formulaspec(parsed.linkwiggle.as_ref(), link_choice.as_ref());
// Optional latent-coordinate block: yields an augmented dataset, a rewritten
// formula and the latent configuration, or nothing at all.
let latent_prepared = prepare_standard_latent_coord(parsed, data, y.view(), config)?;
let (latent_dataset, latent_parsed, mut latent_coord) = match latent_prepared {
Some((dataset, parsed, coord)) => (Some(dataset), Some(parsed), Some(coord)),
None => (None, None, None),
};
// Term construction uses the augmented inputs when a latent block exists and
// the caller's inputs otherwise.
let term_data = latent_dataset.as_ref().unwrap_or(data);
let term_parsed = latent_parsed.as_ref().unwrap_or(parsed);
let term_col_map = term_data.column_map();
let policy =
resolved_resource_policy(config, term_data, crate::resource::ProblemHints::default());
let spec = build_termspec_with_geometry_and_overrides(
&term_parsed.terms,
term_data,
&term_col_map,
&mut inference_notes,
config.scale_dimensions,
&policy,
config.smooth_overrides.as_ref(),
)?;
check_smooth_capacity(&spec, y.len(), &parsed.response)?;
if let Some(coord) = latent_coord.as_mut() {
// Locate the smooth term built on exactly the synthetic latent columns and
// replace the placeholder term index with its position.
let resolved_idx = spec
.smooth_terms
.iter()
.position(|term| {
smooth_basis_feature_cols_for_latent(&term.basis)
.is_some_and(|cols| cols == coord.feature_cols)
})
.ok_or_else(|| {
"latent-coordinate smooth term disappeared during formula materialization"
.to_string()
})?;
coord.term_index = crate::types::SmoothTermIdx::new(resolved_idx);
if coord.manifold_auto {
// Auto mode: adopt the manifold implied by the resolved basis and rebuild
// the shared latent values with it.
let inferred = natural_latent_manifold_for_basis(
&spec.smooth_terms[coord.term_index.get()].basis,
coord.feature_cols.len(),
);
coord.manifold = inferred.clone();
coord.values = Arc::new(coord.values.with_manifold(inferred));
}
}
// Weights and offsets are resolved against the caller's `data` / `col_map`,
// not the latent-augmented dataset.
let weights = resolve_weight_column(data, col_map, config.weight_column.as_deref())?;
let offset = resolve_offset_column(data, col_map, config.offset_column.as_deref())?;
// The latent-cloglog family needs a hazard-multiplier frailty with a fixed
// sigma and full loading; every other frailty configuration is rejected.
let latent_cloglog = if family.is_latent_cloglog() {
let sigma = match config.frailty.clone().unwrap_or(FrailtySpec::None) {
FrailtySpec::HazardMultiplier {
sigma_fixed: Some(sigma),
loading: crate::families::lognormal_kernel::HazardLoading::Full,
} => sigma,
FrailtySpec::HazardMultiplier {
sigma_fixed: Some(_),
loading,
} => {
return Err(WorkflowError::MissingDependency {
reason: format!(
"latent-cloglog-binomial requires HazardLoading::Full, got {loading:?}"
),
}
.into());
}
FrailtySpec::HazardMultiplier {
sigma_fixed: None, ..
} => {
return Err(WorkflowError::MissingDependency {
reason:
"latent-cloglog-binomial currently requires a fixed hazard-multiplier sigma"
.to_string(),
}
.into());
}
FrailtySpec::GaussianShift { .. } => {
return Err(WorkflowError::InvalidConfig {
reason: "latent-cloglog-binomial does not support GaussianShift frailty"
.to_string(),
}
.into());
}
FrailtySpec::None => {
return Err(WorkflowError::MissingDependency {
reason:
"latent-cloglog-binomial requires config.frailty=HazardMultiplier with a fixed sigma"
.to_string(),
}
.into());
}
};
Some(
LatentCLogLogState::new(sigma)
.map_err(|e| format!("invalid latent_cloglog state: {e}"))?,
)
} else {
// Any frailty setting is an error for families other than latent-cloglog.
if config.frailty.is_some() {
return Err(WorkflowError::InvalidConfig {
reason: format!(
"config.frailty is not supported for standard family {:?}; use a frailty-aware family instead",
family
),
}
.into());
}
None
};
let options = FitOptions {
latent_cloglog,
mixture_link: None,
optimize_mixture: false,
sas_link: None,
optimize_sas: false,
compute_inference: true,
skip_rho_posterior_inference: true,
max_iter: config.outer_max_iter.unwrap_or(200),
tol: 1e-10,
nullspace_dims: vec![],
linear_constraints: None,
firth_bias_reduction: config.firth,
adaptive_regularization: standard_adaptive_regularization_options(config),
penalty_shrinkage_floor: Some(1e-6),
rho_prior: Default::default(),
kronecker_penalty_system: None,
kronecker_factored: None,
persist_warm_start_disk: config.persist_warm_start_disk,
};
let kappa_options = SpatialLengthScaleOptimizationOptions::default();
// Link-wiggle configuration is produced only for binomial families. An
// explicit link choice that does not convert to a `StandardLink` yields no
// wiggle at all rather than an error.
let wiggle = effective_linkwiggle.as_ref().and_then(|cfg| {
if !family.is_binomial() {
return None;
}
let link_kind = match link_choice.as_ref() {
Some(c) => match StandardLink::try_from(c.link) {
Ok(std_link) => InverseLink::Standard(std_link),
Err(_) => return None,
},
None => {
// No explicit link: use the latent-cloglog state when present, else logit.
if let Some(state) = latent_cloglog {
InverseLink::LatentCLogLog(state)
} else {
InverseLink::Standard(StandardLink::Logit)
}
}
};
Some(StandardBinomialWiggleConfig {
link_kind,
wiggle: LinkWiggleConfig {
degree: cfg.degree,
num_internal_knots: cfg.num_internal_knots,
penalty_orders: cfg.penalty_orders.clone(),
double_penalty: cfg.double_penalty,
},
refit_options: BlockwiseFitOptions::default(),
})
});
Ok(MaterializedModel {
request: FitRequest::Standard(StandardFitRequest {
// The design data is cloned from the (possibly latent-augmented) term dataset.
data: term_data.values.clone(),
y,
weights,
offset,
spec,
family,
options,
kappa_options,
wiggle,
coefficient_groups: config.coefficient_groups.clone(),
penalty_block_gamma_priors: config.penalty_block_gamma_priors.clone(),
latent_coord,
_marker: std::marker::PhantomData,
}),
inference_notes,
})
}
/// Materializes a Bernoulli marginal-slope fit request.
///
/// Requirements enforced here:
/// * the response is binary `{0, 1}` (see `is_binary_response`);
/// * `noise_formula` is not set;
/// * `logslope_formula` is set, shares the main formula's response, and
///   contains no `link(...)`;
/// * the score `z` comes either from `z_column` or from a CTN Stage-1 recipe
///   (`ctn_stage1`).
///
/// The marginal and log-slope term specs are built against a column map in
/// which `"z"` aliases `z_column` (when given), and each is passed through
/// `prune_unidentified_linear_terms_for_marginal_slope`. The marginal offset is
/// read from `offset_column` and the log-slope offset from
/// `noise_offset_column`.
///
/// # Errors
/// Returns a `WorkflowError` for any violated requirement above, for failures
/// of the helper builders, and — instead of panicking as before — when the
/// cross-fit calibration yields no score and `z_column` is also absent.
pub(crate) fn materialize_bernoulli_marginal_slope<'a>(
    parsed: &ParsedFormula,
    data: &'a Dataset,
    col_map: &HashMap<String, usize>,
    config: &FitConfig,
) -> Result<MaterializedModel<'a>, WorkflowError> {
    let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
    let y = data.values.column(y_col).to_owned();
    if !is_binary_response(y.view()) {
        return Err(WorkflowError::SchemaMismatch {
            reason: "Bernoulli marginal-slope requires a binary {0,1} response".to_string(),
        });
    }
    if config.noise_formula.is_some() {
        return Err(WorkflowError::InvalidConfig {
            reason: "Bernoulli marginal-slope cannot also use noise_formula".to_string(),
        });
    }
    let logslope_formula = config
        .logslope_formula
        .as_deref()
        .ok_or_else(|| "Bernoulli marginal-slope requires logslope_formula".to_string())?;
    let z_column = config.z_column.as_deref();
    if z_column.is_none() && config.ctn_stage1.is_none() {
        return Err(WorkflowError::InvalidConfig {
            reason: "Bernoulli marginal-slope requires z_column (or a CTN Stage-1 recipe via \
                     ctn_stage1, which produces z by cross-fitting)"
                .to_string(),
        });
    }
    let (_, parsed_logslope) =
        parse_matching_auxiliary_formula(logslope_formula, &parsed.response, "logslope_formula")?;
    if parsed_logslope.linkspec.is_some() {
        return Err(WorkflowError::InvalidConfig {
            reason: "link(...) is not supported inside logslope_formula".to_string(),
        });
    }
    if let Some(z_column) = z_column {
        validate_marginal_slope_z_column_exclusion(
            parsed,
            &parsed_logslope,
            z_column,
            "Bernoulli marginal-slope",
            "logslope_formula",
        )?;
    }
    let mut inference_notes = Vec::new();
    let policy = resolved_resource_policy(
        config,
        data,
        crate::resource::ProblemHints {
            marginal_slope_large_scale_active: true,
        },
    );
    // Let formulas refer to the score column as `z` regardless of its real name.
    let aliased_col_map = match z_column {
        Some(z_column) => column_map_with_alias(col_map, "z", z_column),
        None => col_map.clone(),
    };
    let mut marginalspec = build_termspec_with_geometry_and_overrides(
        &parsed.terms,
        data,
        &aliased_col_map,
        &mut inference_notes,
        config.scale_dimensions,
        &policy,
        config.smooth_overrides.as_ref(),
    )?;
    prune_unidentified_linear_terms_for_marginal_slope(
        &mut marginalspec,
        data,
        "bernoulli marginal-slope marginal formula",
        &mut inference_notes,
    )?;
    let mut logslopespec = build_termspec_with_geometry_and_overrides(
        &parsed_logslope.terms,
        data,
        &aliased_col_map,
        &mut inference_notes,
        config.scale_dimensions,
        &policy,
        config.smooth_overrides.as_ref(),
    )?;
    prune_unidentified_linear_terms_for_marginal_slope(
        &mut logslopespec,
        data,
        "bernoulli marginal-slope logslope_formula",
        &mut inference_notes,
    )?;
    let weights = resolve_weight_column(data, col_map, config.weight_column.as_deref())?;
    let marginal_offset = resolve_offset_column(data, col_map, config.offset_column.as_deref())?;
    // The log-slope block takes its offset from `noise_offset_column`.
    let logslope_offset =
        resolve_offset_column(data, col_map, config.noise_offset_column.as_deref())?;
    let routing = route_marginal_slope_deviation_blocks(
        parsed.linkwiggle.as_ref(),
        parsed_logslope.linkwiggle.as_ref(),
    )?;
    // Prefer the cross-fitted score; fall back to the user-supplied column.
    let (z, score_influence_jacobian) =
        match crossfit_score_calibration(data, col_map, config.ctn_stage1.as_ref(), &policy)
            .map_err(|reason| WorkflowError::IntegrationFailed { reason })?
        {
            Some(calibration) => (calibration.z_oof, Some(calibration.jac_oof)),
            None => {
                // The earlier guard only proves that z_column OR ctn_stage1 is set.
                // If calibration yields nothing while z_column is absent, report a
                // configuration error instead of panicking.
                let z_column = z_column.ok_or_else(|| WorkflowError::InvalidConfig {
                    reason: "Bernoulli marginal-slope requires z_column when CTN Stage-1 \
                             cross-fitting yields no calibrated z"
                        .to_string(),
                })?;
                let z_idx = resolve_role_col(col_map, z_column, "z")?;
                let z = data.values.column(z_idx).to_owned();
                validate_bernoulli_marginal_slope_z_column_variance(
                    z_column,
                    z.view(),
                    weights.view(),
                )?;
                (z, None)
            }
        };
    let spec = BernoulliMarginalSlopeTermSpec {
        y,
        weights,
        z,
        base_link: InverseLink::Standard(StandardLink::Probit),
        marginalspec,
        logslopespec,
        marginal_offset,
        logslope_offset,
        frailty: config.frailty.clone().unwrap_or(FrailtySpec::None),
        score_warp: routing.score_warp,
        link_dev: routing.link_dev,
        latent_z_policy: Default::default(),
        score_influence_jacobian,
    };
    Ok(MaterializedModel {
        request: FitRequest::BernoulliMarginalSlope(BernoulliMarginalSlopeFitRequest {
            data: data.values.view(),
            spec,
            options: BlockwiseFitOptions {
                compute_covariance: true,
                ..Default::default()
            },
            kappa_options: SpatialLengthScaleOptimizationOptions::default(),
            policy,
        }),
        inference_notes,
    })
}
pub(crate) fn materialize_survival<'a>(
parsed: &ParsedFormula,
data: &'a Dataset,
col_map: &HashMap<String, usize>,
config: &FitConfig,
entry_col: Option<&str>,
exit_col: &str,
event_col: &str,
interval_right_col: Option<&str>,
) -> Result<MaterializedModel<'a>, WorkflowError> {
let mut inference_notes = Vec::new();
let entry_idx = entry_col
.map(|name| resolve_role_col(col_map, name, "entry"))
.transpose()?;
let exit_idx = resolve_role_col(col_map, exit_col, "exit")?;
let event_idx = resolve_role_col(col_map, event_col, "event")?;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
let n = data.values.nrows();
let event = data.values.column(event_idx).to_owned();
let event_codes = Array1::from_iter(
event
.iter()
.copied()
.enumerate()
.map(|(i, value)| crate::survival::survival_event_code_from_value(value, i))
.collect::<Result<Vec<_>, _>>()?,
);
let pairs: Result<Vec<(f64, f64)>, String> = (0..n)
.into_par_iter()
.map(|i| {
let entry_val = entry_idx.map_or(0.0, |idx| data.values[[i, idx]]);
normalize_survival_time_pair(entry_val, data.values[[i, exit_idx]], i)
})
.collect();
let pairs = pairs?;
let mut age_entry = Array1::<f64>::zeros(n);
let mut age_exit = Array1::<f64>::zeros(n);
for (i, (e, x)) in pairs.into_iter().enumerate() {
age_entry[i] = e;
age_exit[i] = x;
}
let age_right = if let Some(right_col) = interval_right_col {
let right_idx = resolve_role_col(col_map, right_col, "interval right")?;
let mut right = Array1::<f64>::zeros(n);
for i in 0..n {
let r = data.values[[i, right_idx]];
let is_bracketed = data.values[[i, event_idx]] >= 0.5;
if is_bracketed {
if !(r.is_finite()) || r < age_exit[i] {
return Err(WorkflowError::InvalidConfig {
reason: format!(
"SurvInterval(L, R, event) requires a finite R >= L on bracketed rows (event >= 1); row {} has L={}, R={r}",
i + 1,
age_exit[i]
),
});
}
right[i] = r;
} else {
right[i] = age_exit[i];
}
}
Some(right)
} else {
None
};
let survival_mode = parse_survival_likelihood_mode(&config.survival_likelihood)?;
if age_right.is_some() && survival_mode != SurvivalLikelihoodMode::Latent {
return Err(WorkflowError::InvalidConfig {
reason: format!(
"interval-censored SurvInterval(L, R, event) is only defined for the latent \
hazard-window survival likelihood (its kernel carries the log[S(L) − S(R)] \
interval contribution); got survival_likelihood='{}'",
config.survival_likelihood
),
});
}
if !event_codes.iter().any(|&code| code > 0) {
let mode_label = match survival_mode {
SurvivalLikelihoodMode::MarginalSlope => "survival marginal-slope",
_ => "survival fit",
};
return Err(WorkflowError::InvalidConfig {
reason: format!(
"{mode_label} requires at least one target event; all rows are censored, so the likelihood has no event score and cannot identify the hazard"
),
});
}
let cause_count =
crate::survival::cause_count_from_event_codes(event_codes.view()).into_workflow_result()?;
if cause_count > 1
&& !matches!(
survival_mode,
SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull
)
{
return Err(WorkflowError::InvalidConfig {
reason: format!(
"cause-specific competing risks with {cause_count} causes are currently supported for survival_likelihood='transformation' and 'weibull'; got '{}'",
config.survival_likelihood
),
}
.into());
}
if parsed.linkwiggle.is_some()
&& !matches!(
survival_mode,
SurvivalLikelihoodMode::LocationScale | SurvivalLikelihoodMode::MarginalSlope
)
{
return Err(WorkflowError::InvalidConfig {
reason: format!(
"linkwiggle(...) is not defined for survival_likelihood='{}'",
config.survival_likelihood
),
}
.into());
}
if parsed.linkspec.is_some()
&& matches!(
survival_mode,
SurvivalLikelihoodMode::Transformation
| SurvivalLikelihoodMode::Weibull
| SurvivalLikelihoodMode::Latent
| SurvivalLikelihoodMode::LatentBinary
)
{
return Err(WorkflowError::InvalidConfig {
reason: format!(
"link(...) is not implemented for survival_likelihood='{}'",
config.survival_likelihood
),
}
.into());
}
if matches!(survival_mode, SurvivalLikelihoodMode::MarginalSlope)
&& let Some(z_column) = config.z_column.as_deref()
{
let logslope_parsed_for_check = match config.logslope_formula.as_deref() {
Some(ls_formula) => Some(
parse_matching_auxiliary_formula(ls_formula, &parsed.response, "logslope_formula")?
.1,
),
None => None,
};
let logslope_ref = logslope_parsed_for_check.as_ref().unwrap_or(parsed);
validate_marginal_slope_z_column_exclusion(
parsed,
logslope_ref,
z_column,
"survival marginal-slope",
"logslope_formula",
)?;
}
let effective_timewiggle = parsed.timewiggle.clone();
let baseline_target_raw = match survival_mode {
SurvivalLikelihoodMode::Weibull if effective_timewiggle.is_some() => "weibull",
SurvivalLikelihoodMode::Weibull => "linear",
_ => &config.baseline_target,
};
let baseline_cfg = initial_survival_baseline_config_for_fit(
baseline_target_raw,
config.baseline_scale,
config.baseline_shape,
config.baseline_rate,
config.baseline_makeham,
&age_exit,
)?;
if matches!(
survival_mode,
SurvivalLikelihoodMode::Latent | SurvivalLikelihoodMode::LatentBinary
) && baseline_cfg.target == SurvivalBaselineTarget::Linear
{
return Err(
"latent hazard-window families require a non-linear scalar baseline target; use baseline_target weibull, gompertz, or gompertz-makeham"
.to_string()
.into(),
);
}
let time_cfg = if effective_timewiggle.is_some() {
SurvivalTimeBasisConfig::None
} else if survival_mode == SurvivalLikelihoodMode::Weibull {
SurvivalTimeBasisConfig::Linear
} else {
parse_survival_time_basis_config(
&config.time_basis,
config.time_degree,
config.time_num_internal_knots,
config.time_smooth_lambda,
)?
};
let time_anchor = if survival_mode == SurvivalLikelihoodMode::MarginalSlope {
resolve_survival_marginal_slope_time_anchor_value(&age_entry, &age_exit, None)?
} else {
resolve_survival_time_anchor_value(&age_entry, None)?
};
let exact_derivative_guard = survival_derivative_guard_for_likelihood(survival_mode);
let mut time_build = build_survival_time_basis(
&age_entry,
&age_exit,
time_cfg.clone(),
Some((config.time_num_internal_knots, config.time_smooth_lambda)),
)?;
if survival_mode != SurvivalLikelihoodMode::Weibull && effective_timewiggle.is_none() {
require_structural_survival_time_basis(&time_build.basisname, "workflow survival fitting")?;
}
let resolved_time_cfg = resolved_survival_time_basis_config_from_build(
&time_build.basisname,
time_build.degree,
time_build.knots.as_ref(),
time_build.keep_cols.as_ref(),
time_build.smooth_lambda,
)?;
let time_anchor_row = evaluate_survival_time_basis_row(time_anchor, &resolved_time_cfg)?;
center_survival_time_designs_at_anchor(
&mut time_build.x_entry_time,
&mut time_build.x_exit_time,
&time_anchor_row,
)?;
let time_build_right = if let Some(age_right) = age_right.as_ref() {
let mut build_right = build_survival_time_basis(
&age_entry,
age_right,
resolved_time_cfg.clone(),
Some((config.time_num_internal_knots, config.time_smooth_lambda)),
)?;
center_survival_time_designs_at_anchor(
&mut build_right.x_entry_time,
&mut build_right.x_exit_time,
&time_anchor_row,
)?;
Some(build_right)
} else {
None
};
if effective_timewiggle.is_some() && baseline_cfg.target == SurvivalBaselineTarget::Linear {
return Err(
"timewiggle requires a non-linear scalar survival baseline target; \
use baseline_target weibull, gompertz, or gompertz-makeham"
.to_string()
.into(),
);
}
let policy = resolved_resource_policy(
config,
data,
crate::resource::ProblemHints {
marginal_slope_large_scale_active: survival_mode
== SurvivalLikelihoodMode::MarginalSlope,
},
);
let marginal_slope_aliased_col_map = if survival_mode == SurvivalLikelihoodMode::MarginalSlope {
match config.z_column.as_deref() {
Some(z_column) => Some(column_map_with_alias(col_map, "z", z_column)),
None if config.ctn_stage1.is_some() => None,
None => {
return Err(WorkflowError::InvalidConfig {
reason: "marginal-slope survival requires z_column in FitConfig (or a CTN \
Stage-1 recipe via ctn_stage1, which produces z by cross-fitting)"
.to_string(),
});
}
}
} else {
None
};
let termspec_col_map = marginal_slope_aliased_col_map.as_ref().unwrap_or(col_map);
let mut termspec = build_termspec_with_geometry_and_overrides(
&parsed.terms,
data,
termspec_col_map,
&mut inference_notes,
config.scale_dimensions,
&policy,
config.smooth_overrides.as_ref(),
)?;
if survival_mode == SurvivalLikelihoodMode::MarginalSlope {
prune_unidentified_linear_terms_for_marginal_slope(
&mut termspec,
data,
"survival marginal-slope marginal formula",
&mut inference_notes,
)?;
}
let residual_dist = parse_survival_distribution(&config.survival_distribution)?;
let survival_inverse_link = residual_distribution_inverse_link(residual_dist);
let link_choice = parse_link_choice(config.link.as_deref(), config.flexible_link)?;
let effective_linkwiggle =
effectivelinkwiggle_formulaspec(parsed.linkwiggle.as_ref(), link_choice.as_ref());
let effective_linkwiggle_cfg = effective_linkwiggle.clone().map(|cfg| LinkWiggleConfig {
degree: cfg.degree,
num_internal_knots: cfg.num_internal_knots,
penalty_orders: cfg.penalty_orders,
double_penalty: cfg.double_penalty,
});
let weights = resolve_weight_column(data, col_map, config.weight_column.as_deref())?;
let threshold_offset = resolve_offset_column(data, col_map, config.offset_column.as_deref())?;
let log_sigma_offset =
resolve_offset_column(data, col_map, config.noise_offset_column.as_deref())?;
let threshold_template = if let Some(k) = config.threshold_time_k {
build_time_varying_survival_covariate_template(
&age_entry,
&age_exit,
k,
config.threshold_time_degree,
"threshold",
)?
} else {
SurvivalCovariateTermBlockTemplate::Static
};
let log_sigma_template = if let Some(k) = config.sigma_time_k {
build_time_varying_survival_covariate_template(
&age_entry,
&age_exit,
k,
config.sigma_time_degree,
"sigma",
)?
} else {
SurvivalCovariateTermBlockTemplate::Static
};
let log_sigmaspec = if let Some(noise) = config.noise_formula.as_deref() {
let mut noise_parsed = parse_formula(&format!("{} ~ {noise}", parsed.response))?;
apply_secondary_predictor_basis_parsimony(&mut noise_parsed.terms, data.values.nrows());
build_termspec_with_geometry_and_overrides(
&noise_parsed.terms,
data,
termspec_col_map,
&mut inference_notes,
config.scale_dimensions,
&policy,
config.smooth_overrides.as_ref(),
)?
} else {
TermCollectionSpec {
linear_terms: vec![],
random_effect_terms: vec![],
smooth_terms: vec![],
}
};
let marginal_z_column_name = if survival_mode == SurvivalLikelihoodMode::MarginalSlope {
match config.z_column.as_deref() {
Some(name) => Some(name),
None if config.ctn_stage1.is_some() => None,
None => {
return Err(WorkflowError::InvalidConfig {
reason: "marginal-slope survival requires z_column in FitConfig (or a CTN \
Stage-1 recipe via ctn_stage1, which produces z by cross-fitting)"
.to_string(),
});
}
}
} else {
None
};
let (
marginal_z,
marginal_logslopespec,
marginal_logslopespecs,
marginal_slope_deviation_routing,
marginal_slope_base_link,
) = if survival_mode == SurvivalLikelihoodMode::MarginalSlope {
let base_link = resolve_survival_marginal_slope_base_link(parsed.linkspec.as_ref())?;
if marginal_z_column_name.is_none() {
let placeholder_z = Array2::<f64>::zeros((data.values.nrows(), 1));
let (logslopespec, routing) = if let Some(ls_formula) =
config.logslope_formula.as_deref()
{
let (_, ls_parsed) = parse_matching_auxiliary_formula(
ls_formula,
&parsed.response,
"logslope_formula",
)?;
if ls_parsed.linkspec.is_some() {
return Err(
"link(...) is not supported in logslope_formula for the survival marginal-slope family"
.to_string()
.into(),
);
}
if ls_parsed.timewiggle.is_some() {
return Err(
"timewiggle(...) is not supported in logslope_formula for the survival marginal-slope family"
.to_string()
.into(),
);
}
if ls_parsed.survivalspec.is_some() {
return Err(
"survmodel(...) is not supported in logslope_formula for the survival marginal-slope family"
.to_string()
.into(),
);
}
let mut spec = build_termspec_with_geometry_and_overrides(
&ls_parsed.terms,
data,
col_map,
&mut inference_notes,
config.scale_dimensions,
&policy,
config.smooth_overrides.as_ref(),
)?;
prune_unidentified_linear_terms_for_marginal_slope(
&mut spec,
data,
"survival marginal-slope logslope_formula",
&mut inference_notes,
)?;
let routing = route_marginal_slope_deviation_blocks(
parsed.linkwiggle.as_ref(),
ls_parsed.linkwiggle.as_ref(),
)?;
(spec, routing)
} else {
(
termspec.clone(),
route_marginal_slope_deviation_blocks(parsed.linkwiggle.as_ref(), None)?,
)
};
(
Some(placeholder_z),
Some(logslopespec.clone()),
Some(vec![logslopespec]),
routing,
Some(base_link),
)
} else if let Some(ls_formula) = config.logslope_formula.as_deref() {
let default_z_column = marginal_z_column_name.expect("z column present when no recipe");
let (_, ls_parsed) =
parse_matching_auxiliary_formula(ls_formula, &parsed.response, "logslope_formula")?;
if ls_parsed.linkspec.is_some() {
return Err(
"link(...) is not supported in logslope_formula for the survival marginal-slope family"
.to_string()
.into(),
);
}
if ls_parsed.timewiggle.is_some() {
return Err(
"timewiggle(...) is not supported in logslope_formula for the survival marginal-slope family"
.to_string()
.into(),
);
}
if ls_parsed.survivalspec.is_some() {
return Err(
"survmodel(...) is not supported in logslope_formula for the survival marginal-slope family"
.to_string()
.into(),
);
}
validate_marginal_slope_z_column_exclusion(
parsed,
&ls_parsed,
default_z_column,
"survival marginal-slope",
"logslope_formula",
)?;
let surfaces = marginal_slope_logslope_surfaces(&ls_parsed, default_z_column)?;
let mut z = Array2::<f64>::zeros((data.values.nrows(), surfaces.len()));
let mut specs = Vec::with_capacity(surfaces.len());
for (surface_idx, surface) in surfaces.iter().enumerate() {
let z_idx = resolve_role_col(col_map, &surface.z_column, "z")?;
z.column_mut(surface_idx).assign(&data.values.column(z_idx));
let aliased_col_map = column_map_with_alias(col_map, "z", &surface.z_column);
let mut spec = build_termspec_with_geometry_and_overrides(
&surface.terms,
data,
&aliased_col_map,
&mut inference_notes,
config.scale_dimensions,
&policy,
config.smooth_overrides.as_ref(),
)?;
prune_unidentified_linear_terms_for_marginal_slope(
&mut spec,
data,
"survival marginal-slope logslope_formula",
&mut inference_notes,
)?;
specs.push(spec);
}
(
Some(z),
specs.first().cloned(),
Some(specs),
route_marginal_slope_deviation_blocks(
parsed.linkwiggle.as_ref(),
ls_parsed.linkwiggle.as_ref(),
)?,
Some(base_link),
)
} else {
let default_z_column = marginal_z_column_name.expect("z column present when no recipe");
validate_marginal_slope_z_column_exclusion(
parsed,
parsed,
default_z_column,
"survival marginal-slope",
"logslope_formula",
)?;
let z_idx = resolve_role_col(col_map, default_z_column, "z")?;
let z = data.values.column(z_idx).to_owned().insert_axis(Axis(1));
(
Some(z),
Some(termspec.clone()),
Some(vec![termspec.clone()]),
route_marginal_slope_deviation_blocks(parsed.linkwiggle.as_ref(), None)?,
Some(base_link),
)
}
} else {
(
None,
None,
None,
MarginalSlopeDeviationRouting {
score_warp: None,
link_dev: None,
},
None,
)
};
let marginal_slope_score_warp = marginal_slope_deviation_routing.score_warp;
let marginal_slope_link_dev = marginal_slope_deviation_routing.link_dev;
let crossfit_calibration = if survival_mode == SurvivalLikelihoodMode::MarginalSlope {
crossfit_score_calibration(data, col_map, config.ctn_stage1.as_ref(), &policy)
.map_err(|reason| WorkflowError::IntegrationFailed { reason })?
} else {
None
};
let (marginal_z, marginal_slope_jac_oof) = match (marginal_z, crossfit_calibration) {
(Some(mut z_surfaces), Some(calibration)) => {
if z_surfaces.ncols() != 1 {
return Err(WorkflowError::InvalidConfig {
reason: format!(
"cross-fitted score calibration applies to a single CTN-generated z \
surface, but the survival marginal-slope model has {} z surfaces; \
multi-surface logslope is incompatible with the CTN Stage-1 chain",
z_surfaces.ncols()
),
});
}
z_surfaces.column_mut(0).assign(&calibration.z_oof);
(Some(z_surfaces), Some(calibration.jac_oof))
}
(z, _) => (z, None),
};
if survival_mode == SurvivalLikelihoodMode::MarginalSlope {
if parsed.linkwiggle.is_some() {
inference_notes.push(
"survival marginal-slope routes formula-level linkwiggle(...) into its anchored internal link-deviation block while keeping the probit survival base link".to_string(),
);
}
if marginal_slope_score_warp.is_some() {
inference_notes.push(
"survival marginal-slope routes logslope_formula linkwiggle(...) into its anchored internal score-warp block while keeping the probit survival base link".to_string(),
);
}
if marginal_slope_link_dev.is_none() && marginal_slope_score_warp.is_none() {
inference_notes.push(
"survival marginal-slope rigid mode is algebraic closed-form exact".to_string(),
);
} else {
inference_notes.push(
"survival marginal-slope flexible score/link mode uses calibrated de-nested cubic transport cells with analytic value evaluation and calibrated survival normalization"
.to_string(),
);
}
}
let marginal_slope_frailty = if survival_mode == SurvivalLikelihoodMode::MarginalSlope {
Some(fixed_gaussian_shift_frailty_from_spec(
config.frailty.as_ref().unwrap_or(&FrailtySpec::None),
"survival marginal-slope",
)?)
} else {
None
};
match survival_mode {
SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull
if config.frailty.is_some() =>
{
return Err(WorkflowError::InvalidConfig {
reason: "frailty is not supported for transformation/weibull survival models"
.to_string(),
}
.into());
}
SurvivalLikelihoodMode::LocationScale if config.frailty.is_some() => {
return Err(WorkflowError::InvalidConfig {
reason: "config.frailty is not implemented for survival-likelihood=location-scale"
.to_string(),
}
.into());
}
SurvivalLikelihoodMode::Latent | SurvivalLikelihoodMode::LatentBinary
if effective_timewiggle.is_some() =>
{
return Err(WorkflowError::InvalidConfig {
reason: "timewiggle is not implemented for latent survival/binary likelihoods"
.to_string(),
}
.into());
}
_ => {}
}
let latent_loading = if matches!(
survival_mode,
SurvivalLikelihoodMode::Latent | SurvivalLikelihoodMode::LatentBinary
) {
let frailty = config.frailty.as_ref().unwrap_or(&FrailtySpec::None);
Some(latent_hazard_loading(
frailty,
"workflow latent survival/binary",
)?)
} else {
None
};
let build_time_block =
|candidate: &crate::families::survival_construction::SurvivalBaselineConfig| {
let prepared = prepare_survival_time_stack(
&age_entry,
&age_exit,
candidate,
survival_mode,
(survival_mode == SurvivalLikelihoodMode::LocationScale)
.then_some(&survival_inverse_link),
time_anchor,
exact_derivative_guard,
&time_build,
effective_timewiggle.as_ref(),
None,
)?;
let time_p = prepared.time_design_exit.ncols();
let time_initial_log_lambdas = if prepared.time_penalties.is_empty() {
None
} else {
Some(Array1::from_elem(
prepared.time_penalties.len(),
config.time_smooth_lambda.ln(),
))
};
let initial_beta = if survival_mode == SurvivalLikelihoodMode::LocationScale {
None
} else {
Some(Array1::from_elem(time_p, 1e-4))
};
let time_block = TimeBlockInput {
design_entry: prepared.time_design_entry.clone(),
design_exit: prepared.time_design_exit.clone(),
design_derivative_exit: prepared.time_design_derivative_exit.clone(),
offset_entry: prepared.eta_offset_entry.clone(),
offset_exit: prepared.eta_offset_exit.clone(),
derivative_offset_exit: prepared.derivative_offset_exit.clone(),
time_monotonicity: crate::families::survival_location_scale::TimeBlockMonotonicity::EnforcedByCoordinateCone,
penalties: prepared.time_penalties.clone(),
nullspace_dims: prepared.time_nullspace_dims.clone(),
initial_log_lambdas: time_initial_log_lambdas,
initial_beta,
};
Ok::<_, String>((prepared, time_block))
};
let location_scale_smoothing_warm_start: RefCell<Option<(Array1<f64>, Array1<f64>)>> =
RefCell::new(None);
let build_location_scale_request =
|candidate: &crate::families::survival_construction::SurvivalBaselineConfig,
allow_inverse_link_optimization: bool| {
let (prepared, time_block) = build_time_block(candidate)?;
let (initial_threshold_log_lambdas, initial_log_sigma_log_lambdas) =
match location_scale_smoothing_warm_start.borrow().as_ref() {
Some((thr, lsg)) => (Some(thr.clone()), Some(lsg.clone())),
None => (None, None),
};
let spec = SurvivalLocationScaleTermSpec {
age_entry: age_entry.clone(),
age_exit: age_exit.clone(),
event_target: event.clone(),
weights: weights.clone(),
inverse_link: survival_inverse_link.clone(),
derivative_guard: exact_derivative_guard,
max_iter: 200,
tol: 1e-7,
time_block,
thresholdspec: termspec.clone(),
log_sigmaspec: log_sigmaspec.clone(),
threshold_offset: threshold_offset.clone(),
log_sigma_offset: log_sigma_offset.clone(),
threshold_template: threshold_template.clone(),
log_sigma_template: log_sigma_template.clone(),
timewiggle_block: prepared.timewiggle_block,
linkwiggle_block: None,
initial_threshold_log_lambdas,
initial_log_sigma_log_lambdas,
cache_session: None,
cache_mirror_sessions: Vec::new(),
};
let optimize_inverse_link = allow_inverse_link_optimization
&& survival_inverse_link_has_free_parameters(&spec.inverse_link);
Ok::<_, String>(SurvivalLocationScaleFitRequest {
data: data.values.view(),
spec,
wiggle: effective_linkwiggle_cfg.clone(),
kappa_options: SpatialLengthScaleOptimizationOptions::default(),
optimize_inverse_link,
cache_session: None,
})
};
let build_marginal_slope_request =
|candidate: &crate::families::survival_construction::SurvivalBaselineConfig| {
let (prepared, mut time_block) = build_time_block(candidate)?;
time_block.time_monotonicity =
crate::families::survival_location_scale::TimeBlockMonotonicity::EnforcedByRowConstraint;
Ok::<_, String>(SurvivalMarginalSlopeFitRequest {
data: data.values.view(),
spec: SurvivalMarginalSlopeTermSpec {
age_entry: age_entry.clone(),
age_exit: age_exit.clone(),
event_target: event.clone(),
weights: weights.clone(),
z: marginal_z.clone().ok_or_else(|| {
"marginal-slope survival requires z_column in FitConfig".to_string()
})?,
base_link: marginal_slope_base_link.clone().ok_or_else(|| {
"internal error: marginal-slope base link validation missing".to_string()
})?,
marginalspec: termspec.clone(),
marginal_offset: threshold_offset.clone(),
frailty: marginal_slope_frailty.clone().ok_or_else(|| {
"internal error: marginal-slope frailty validation missing".to_string()
})?,
derivative_guard: exact_derivative_guard,
time_block,
timewiggle_block: prepared.timewiggle_block,
logslopespec: marginal_logslopespec.clone().ok_or_else(|| {
"marginal-slope survival is missing logslope spec".to_string()
})?,
logslopespecs: marginal_logslopespecs.clone(),
logslope_offset: log_sigma_offset.clone(),
score_warp: marginal_slope_score_warp.clone(),
link_dev: marginal_slope_link_dev.clone(),
latent_z_policy: Default::default(),
score_influence_jacobian: marginal_slope_jac_oof.clone(),
},
options: BlockwiseFitOptions {
compute_covariance: false,
..Default::default()
},
kappa_options: SpatialLengthScaleOptimizationOptions::default(),
})
};
let build_latent_survival_request =
|candidate: &crate::families::survival_construction::SurvivalBaselineConfig| {
let loading = latent_loading.ok_or_else(|| {
"internal error: latent survival loading missing after frailty validation"
.to_string()
})?;
let prepared = prepare_survival_time_stack(
&age_entry,
&age_exit,
candidate,
survival_mode,
None,
time_anchor,
exact_derivative_guard,
&time_build,
None,
Some(loading),
)?;
let (time_design_right, time_offset_right, unloaded_mass_right, event_target) =
if let (Some(age_right), Some(time_build_right)) =
(age_right.as_ref(), time_build_right.as_ref())
{
let prepared_right = prepare_survival_time_stack(
&age_entry,
age_right,
candidate,
survival_mode,
None,
time_anchor,
exact_derivative_guard,
time_build_right,
None,
Some(loading),
)?;
if prepared_right.time_design_exit.ncols() != prepared.time_design_exit.ncols()
{
return Err(format!(
"interval-censored right time design has {} columns but the left/exit design has {}; the right boundary basis must share the exit basis columns",
prepared_right.time_design_exit.ncols(),
prepared.time_design_exit.ncols()
));
}
let event_target = event.mapv(|v| {
if v >= 0.5 {
crate::families::latent_survival::LATENT_SURVIVAL_EVENT_INTERVAL
} else {
0
}
});
(
Some(prepared_right.time_design_exit.clone()),
Some(prepared_right.eta_offset_exit.clone()),
prepared_right.unloaded_mass_exit.clone(),
event_target,
)
} else {
(
None,
None,
Array1::zeros(0),
event.mapv(|v| if v >= 0.5 { 1 } else { 0 }),
)
};
let time_p = prepared.time_design_exit.ncols();
let time_initial_log_lambdas = if prepared.time_penalties.is_empty() {
None
} else {
Some(Array1::from_elem(
prepared.time_penalties.len(),
config.time_smooth_lambda.ln(),
))
};
let time_block = TimeBlockInput {
design_entry: prepared.time_design_entry.clone(),
design_exit: prepared.time_design_exit.clone(),
design_derivative_exit: prepared.time_design_derivative_exit.clone(),
offset_entry: prepared.eta_offset_entry.clone(),
offset_exit: prepared.eta_offset_exit.clone(),
derivative_offset_exit: prepared.derivative_offset_exit.clone(),
time_monotonicity: crate::families::survival_location_scale::TimeBlockMonotonicity::EnforcedByCoordinateCone,
penalties: prepared.time_penalties.clone(),
nullspace_dims: prepared.time_nullspace_dims.clone(),
initial_log_lambdas: time_initial_log_lambdas,
initial_beta: Some(Array1::from_elem(time_p, 1e-4)),
};
Ok::<_, String>(LatentSurvivalFitRequest {
data: data.values.view(),
spec: LatentSurvivalTermSpec {
age_entry: age_entry.clone(),
age_exit: age_exit.clone(),
event_target,
weights: weights.clone(),
derivative_guard: exact_derivative_guard,
time_block,
time_design_right,
time_offset_right,
unloaded_mass_entry: prepared.unloaded_mass_entry,
unloaded_mass_exit: prepared.unloaded_mass_exit,
unloaded_mass_right,
unloaded_hazard_exit: prepared.unloaded_hazard_exit,
meanspec: termspec.clone(),
mean_offset: threshold_offset.clone(),
},
frailty: config.frailty.clone().unwrap_or(FrailtySpec::None),
options: BlockwiseFitOptions::default(),
})
};
let build_latent_binary_request =
|candidate: &crate::families::survival_construction::SurvivalBaselineConfig| {
let loading = latent_loading.ok_or_else(|| {
"internal error: latent binary loading missing after frailty validation".to_string()
})?;
let prepared = prepare_survival_time_stack(
&age_entry,
&age_exit,
candidate,
survival_mode,
None,
time_anchor,
exact_derivative_guard,
&time_build,
None,
Some(loading),
)?;
let time_p = prepared.time_design_exit.ncols();
let time_initial_log_lambdas = if prepared.time_penalties.is_empty() {
None
} else {
Some(Array1::from_elem(
prepared.time_penalties.len(),
config.time_smooth_lambda.ln(),
))
};
let time_block = TimeBlockInput {
design_entry: prepared.time_design_entry.clone(),
design_exit: prepared.time_design_exit.clone(),
design_derivative_exit: prepared.time_design_derivative_exit.clone(),
offset_entry: prepared.eta_offset_entry.clone(),
offset_exit: prepared.eta_offset_exit.clone(),
derivative_offset_exit: prepared.derivative_offset_exit.clone(),
time_monotonicity: crate::families::survival_location_scale::TimeBlockMonotonicity::EnforcedByCoordinateCone,
penalties: prepared.time_penalties.clone(),
nullspace_dims: prepared.time_nullspace_dims.clone(),
initial_log_lambdas: time_initial_log_lambdas,
initial_beta: Some(Array1::from_elem(time_p, 1e-4)),
};
Ok::<_, String>(LatentBinaryFitRequest {
data: data.values.view(),
spec: LatentBinaryTermSpec {
age_entry: age_entry.clone(),
age_exit: age_exit.clone(),
event_target: event.mapv(|v| if v >= 0.5 { 1 } else { 0 }),
weights: weights.clone(),
derivative_guard: exact_derivative_guard,
time_block,
unloaded_mass_entry: prepared.unloaded_mass_entry,
unloaded_mass_exit: prepared.unloaded_mass_exit,
meanspec: termspec.clone(),
mean_offset: threshold_offset.clone(),
},
frailty: config.frailty.clone().unwrap_or(FrailtySpec::None),
options: BlockwiseFitOptions::default(),
})
};
let baseline_cfg = if matches!(
survival_mode,
SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull
) {
baseline_cfg
} else if baseline_cfg.target != SurvivalBaselineTarget::Linear
&& survival_mode == SurvivalLikelihoodMode::MarginalSlope
{
optimize_survival_baseline_config_with_gradient(
&baseline_cfg,
"workflow survival marginal-slope baseline",
|candidate| {
let fit =
fit_survival_marginal_slope_model(build_marginal_slope_request(candidate)?)
.map_err(|e| format!("survival marginal-slope fit failed: {e}"))?;
let gradient = marginal_slope_baseline_chain_rule_gradient(
age_entry.view(),
age_exit.view(),
candidate,
&fit.baseline_offset_residuals,
)?
.ok_or_else(|| {
"workflow survival marginal-slope baseline unexpectedly has no theta gradient"
.to_string()
})?;
let hessian = marginal_slope_baseline_chain_rule_hessian(
age_entry.view(),
age_exit.view(),
candidate,
&fit.baseline_offset_residuals,
&fit.baseline_offset_curvatures,
)?
.ok_or_else(|| {
"workflow survival marginal-slope baseline unexpectedly has no theta Hessian"
.to_string()
})?;
Ok((fit.fit.reml_score, gradient, hessian))
},
)?
} else if baseline_cfg.target != SurvivalBaselineTarget::Linear
&& survival_mode == SurvivalLikelihoodMode::LocationScale
{
let probit_channel =
location_scale_uses_probit_survival_baseline(Some(&survival_inverse_link));
let baseline_outcome = optimize_survival_baseline_config_with_gradient_only(
&baseline_cfg,
"workflow survival location-scale baseline",
|candidate| {
let fit_result = fit_survival_location_scale_model(build_location_scale_request(
candidate, false,
)?)
.map_err(|e| format!("survival location-scale fit failed: {e}"))?;
let threshold_rho = fit_result.fit.fit.lambdas_threshold().mapv(f64::ln);
let log_sigma_rho = fit_result.fit.fit.lambdas_log_sigma().mapv(f64::ln);
*location_scale_smoothing_warm_start.borrow_mut() =
Some((threshold_rho, log_sigma_rho));
let residuals = &fit_result.fit.baseline_offset_residuals;
let gradient = if probit_channel {
marginal_slope_baseline_chain_rule_gradient(
age_entry.view(),
age_exit.view(),
candidate,
residuals,
)?
} else {
baseline_chain_rule_gradient(
age_entry.view(),
age_exit.view(),
age_exit.view(),
candidate,
residuals,
)?
}
.ok_or_else(|| {
"workflow survival location-scale baseline unexpectedly has no theta gradient"
.to_string()
})?;
let profile_cost = -fit_result.fit.fit.log_likelihood
+ 0.5 * fit_result.fit.fit.stable_penalty_term;
if !profile_cost.is_finite() {
return Err(format!(
"workflow survival location-scale baseline: non-finite profile cost \
(log_likelihood={}, stable_penalty_term={}, cost={})",
fit_result.fit.fit.log_likelihood,
fit_result.fit.fit.stable_penalty_term,
profile_cost
));
}
Ok((profile_cost, gradient))
},
);
match baseline_outcome {
Ok(baseline) => baseline,
Err(e)
if e.contains("expects 3 blocks, got 0")
|| e.contains("expects 4 blocks, got 0")
|| (e.contains("block_states") && e.contains("got 0"))
|| e.contains("blockwise fit requires at least one block state")
|| e.contains(SURVIVAL_LOCATION_SCALE_EMPTY_BLOCK_STATES_MARKER) =>
{
log::warn!(
"workflow survival location-scale baseline: gradient-only BFGS \
failed at an empty-block_states candidate ({e}); falling back \
to the seed baseline_cfg as-is"
);
baseline_cfg.clone()
}
Err(e) => return Err(e.into()),
}
} else if baseline_cfg.target != SurvivalBaselineTarget::Linear {
let baseline_outcome = optimize_survival_baseline_config_with_gradient_only(
&baseline_cfg,
"workflow latent survival baseline",
|candidate| {
let (log_likelihood, stable_penalty_term, residuals) = match survival_mode {
SurvivalLikelihoodMode::Latent => {
let request = build_latent_survival_request(candidate)?;
match fit_model(FitRequest::LatentSurvival(request)) {
Ok(FitResult::LatentSurvival(result)) => (
result.fit.log_likelihood,
result.fit.stable_penalty_term,
result.baseline_offset_residuals,
),
Ok(_) => {
return Err("internal latent survival workflow returned the wrong result variant".to_string());
}
Err(e) => return Err(format!("latent survival fit failed: {e}")),
}
}
SurvivalLikelihoodMode::LatentBinary => {
let request = build_latent_binary_request(candidate)?;
match fit_model(FitRequest::LatentBinary(request)) {
Ok(FitResult::LatentBinary(result)) => (
result.fit.log_likelihood,
result.fit.stable_penalty_term,
result.baseline_offset_residuals,
),
Ok(_) => {
return Err("internal latent binary workflow returned the wrong result variant".to_string());
}
Err(e) => return Err(format!("latent binary fit failed: {e}")),
}
}
SurvivalLikelihoodMode::Transformation
| SurvivalLikelihoodMode::Weibull
| SurvivalLikelihoodMode::LocationScale
| SurvivalLikelihoodMode::MarginalSlope => {
return Err(format!(
"internal: workflow latent baseline closure reached for non-latent mode {survival_mode:?}"
));
}
};
let profile_cost = -log_likelihood + 0.5 * stable_penalty_term;
if !profile_cost.is_finite() {
return Err(format!(
"workflow latent baseline: non-finite profile cost \
(log_likelihood={log_likelihood}, \
stable_penalty_term={stable_penalty_term}, cost={profile_cost})"
));
}
let age_right_view = age_right.as_ref().unwrap_or(&age_exit);
let gradient = baseline_chain_rule_gradient(
age_entry.view(),
age_exit.view(),
age_right_view.view(),
candidate,
&residuals,
)?
.ok_or_else(|| {
"workflow latent baseline unexpectedly has no theta gradient".to_string()
})?;
Ok((profile_cost, gradient))
},
);
match baseline_outcome {
Ok(baseline) => baseline,
Err(e)
if e.contains("expects 3 blocks, got 0")
|| e.contains("expects 4 blocks, got 0")
|| (e.contains("block_states") && e.contains("got 0"))
|| e.contains("blockwise fit requires at least one block state")
|| e.contains(SURVIVAL_LOCATION_SCALE_EMPTY_BLOCK_STATES_MARKER) =>
{
log::warn!(
"workflow latent survival baseline: gradient-only BFGS failed at an \
empty-block_states candidate ({e}); falling back to the seed \
baseline_cfg as-is"
);
baseline_cfg.clone()
}
Err(e) => return Err(WorkflowError::InvalidConfig { reason: e }.into()),
}
} else {
baseline_cfg
};
let request = match survival_mode {
SurvivalLikelihoodMode::Transformation | SurvivalLikelihoodMode::Weibull => {
if config.noise_offset_column.is_some() {
return Err(WorkflowError::InvalidConfig {
reason:
"noise_offset_column is supported only for survival location-scale or marginal-slope"
.to_string(),
}
.into());
}
let weibull_seed = if survival_mode == SurvivalLikelihoodMode::Weibull
&& effective_timewiggle.is_none()
{
let scale = config
.baseline_scale
.unwrap_or_else(|| positive_survival_time_seed(&age_exit));
let shape = config.baseline_shape.unwrap_or(1.0);
if !scale.is_finite() || scale <= 0.0 || !shape.is_finite() || shape <= 0.0 {
return Err(WorkflowError::InvalidConfig {
reason:
"weibull survival fit requires finite positive baseline_scale and baseline_shape"
.to_string(),
}
.into());
}
Some((scale, shape))
} else {
None
};
FitRequest::SurvivalTransformation(SurvivalTransformationFitRequest {
data: data.values.view(),
spec: SurvivalTransformationTermSpec {
age_entry: age_entry.clone(),
age_exit: age_exit.clone(),
event_target: event_codes.clone(),
weights: weights.clone(),
covariate_spec: termspec.clone(),
covariate_offset: threshold_offset.clone(),
baseline_cfg,
likelihood_mode: survival_mode,
time_anchor,
time_build: time_build.clone(),
timewiggle: effective_timewiggle.clone(),
weibull_seed,
ridge_lambda: config.ridge_lambda,
penalty_block_gamma_priors: config.penalty_block_gamma_priors.clone(),
},
cache_session: None,
})
}
SurvivalLikelihoodMode::LocationScale => {
FitRequest::SurvivalLocationScale(build_location_scale_request(&baseline_cfg, true)?)
}
SurvivalLikelihoodMode::MarginalSlope => {
FitRequest::SurvivalMarginalSlope(build_marginal_slope_request(&baseline_cfg)?)
}
SurvivalLikelihoodMode::Latent => {
FitRequest::LatentSurvival(build_latent_survival_request(&baseline_cfg)?)
}
SurvivalLikelihoodMode::LatentBinary => {
FitRequest::LatentBinary(build_latent_binary_request(&baseline_cfg)?)
}
};
Ok(MaterializedModel {
request,
inference_notes,
})
}
/// Assembles the fit request for a transformation-normal model.
///
/// Controls that this family does not accept are rejected up front with
/// `WorkflowError::InvalidConfig`: a formula-level `link(...)` or
/// `linkwiggle(...)`, and the `noise_offset_column` / `frailty` config fields.
/// After that the response column, covariate term spec, weights and offset are
/// resolved (in that order) and packaged into a
/// `FitRequest::TransformationNormal` using default family config, blockwise
/// options and length-scale options, with no warm start.
///
/// # Errors
///
/// Returns the first error produced by the rejection checks, response-column
/// resolution, term-spec construction, or weight/offset column resolution.
pub(crate) fn materialize_transformation_normal<'a>(
    parsed: &ParsedFormula,
    data: &'a Dataset,
    col_map: &HashMap<String, usize>,
    config: &FitConfig,
) -> Result<MaterializedModel<'a>, WorkflowError> {
    // Ordered table of unsupported controls; the first one present decides
    // which message is reported.
    let unsupported_controls = [
        (
            parsed.linkspec.is_some(),
            "link(...) is not supported for the transformation-normal family",
        ),
        (
            parsed.linkwiggle.is_some(),
            "linkwiggle(...) is not supported for the transformation-normal family",
        ),
        (
            config.noise_offset_column.is_some(),
            "noise_offset_column is not supported for transformation-normal models",
        ),
        (
            config.frailty.is_some(),
            "frailty is not supported for transformation-normal models",
        ),
    ];
    if let Some(&(_, reason)) = unsupported_controls.iter().find(|(present, _)| *present) {
        return Err(WorkflowError::InvalidConfig {
            reason: reason.to_string(),
        });
    }

    // Response column is resolved before anything else touches the data.
    let response_idx = resolve_role_col(col_map, &parsed.response, "response")?;
    let response = data.values.column(response_idx).to_owned();

    let mut inference_notes = Vec::new();
    let policy = resolved_resource_policy(config, data, marginal_slope_hints(config));
    let covariate_spec = build_termspec_with_geometry_and_overrides(
        &parsed.terms,
        data,
        col_map,
        &mut inference_notes,
        config.scale_dimensions,
        &policy,
        config.smooth_overrides.as_ref(),
    )?;

    let weights = resolve_weight_column(data, col_map, config.weight_column.as_deref())?;
    let offset = resolve_offset_column(data, col_map, config.offset_column.as_deref())?;

    let request = FitRequest::TransformationNormal(TransformationNormalFitRequest {
        data: data.values.view(),
        response,
        weights,
        offset,
        covariate_spec,
        config: TransformationNormalConfig::default(),
        options: BlockwiseFitOptions::default(),
        kappa_options: SpatialLengthScaleOptimizationOptions::default(),
        warm_start: None,
    });
    Ok(MaterializedModel {
        request,
        inference_notes,
    })
}
/// Writes a conservative center cap into the options of eligible smooth terms.
///
/// Only `ParsedTerm::Smooth` entries are considered. A smooth is eligible when
/// its resolved smooth type uses the spatial-center heuristic and its options
/// carry no explicit `"centers"` setting. For each eligible smooth, the value
/// of `conservative_secondary_centers(n_rows, vars.len())` is stored as a
/// string under `SECONDARY_CENTER_CAP_OPTION`. All other terms are left
/// untouched.
fn apply_secondary_predictor_basis_parsimony(terms: &mut [ParsedTerm], n_rows: usize) {
    for term in terms {
        let ParsedTerm::Smooth {
            vars,
            kind,
            options,
            ..
        } = term
        else {
            continue;
        };

        let smooth_type = resolve_smooth_type_name(*kind, vars.len(), options);
        // An explicit user-supplied center count always wins over the cap.
        let wants_cap = smooth_type_uses_spatial_center_heuristic(&smooth_type)
            && !has_explicit_countwith_basis_alias(options, "centers");
        if wants_cap {
            let center_cap =
                crate::terms::basis::conservative_secondary_centers(n_rows, vars.len());
            options.insert(
                SECONDARY_CENTER_CAP_OPTION.to_string(),
                center_cap.to_string(),
            );
        }
    }
}
pub(crate) fn materialize_location_scale<'a>(
parsed: &ParsedFormula,
data: &'a Dataset,
col_map: &HashMap<String, usize>,
config: &FitConfig,
) -> Result<MaterializedModel<'a>, WorkflowError> {
let y_col = resolve_role_col(col_map, &parsed.response, "response")?;
let y = data.values.column(y_col).to_owned();
let y_kind = response_column_kind(data, y_col);
let mut inference_notes = Vec::new();
let noise_formula = config
.noise_formula
.as_deref()
.ok_or_else(|| "noise_formula is required for location-scale models".to_string())?;
let mut noise_parsed = parse_formula(&format!("{} ~ {noise_formula}", parsed.response))?;
apply_secondary_predictor_basis_parsimony(&mut noise_parsed.terms, data.values.nrows());
let link_choice = parse_link_choice(config.link.as_deref(), config.flexible_link)?;
let family = resolve_family(
config.family.as_deref(),
config.negative_binomial_theta,
link_choice.as_ref(),
y.view(),
y_kind,
&parsed.response,
)?;
family
.response
.validate_response_support(y.view())
.map_err(|violation| violation.message_for(&parsed.response))?;
family
.response
.validate_response_degeneracy(y.view())
.map_err(|deg| deg.message_for(&parsed.response))?;
reject_explicit_linkwiggle_for_nonbinomial(parsed, &family)?;
let effective_linkwiggle =
effectivelinkwiggle_formulaspec(parsed.linkwiggle.as_ref(), link_choice.as_ref());
let policy = resolved_resource_policy(config, data, crate::resource::ProblemHints::default());
let meanspec = build_termspec_with_geometry_and_overrides(
&parsed.terms,
data,
col_map,
&mut inference_notes,
config.scale_dimensions,
&policy,
config.smooth_overrides.as_ref(),
)?;
let log_sigmaspec = build_termspec_with_geometry_and_overrides(
&noise_parsed.terms,
data,
col_map,
&mut inference_notes,
config.scale_dimensions,
&policy,
config.smooth_overrides.as_ref(),
)?;
check_smooth_capacity(&meanspec, y.len(), &parsed.response)?;
check_smooth_capacity(&log_sigmaspec, y.len(), &parsed.response)?;
let weights = resolve_weight_column(data, col_map, config.weight_column.as_deref())?;
let mean_offset = resolve_offset_column(data, col_map, config.offset_column.as_deref())?;
let noise_offset = resolve_offset_column(data, col_map, config.noise_offset_column.as_deref())?;
let kappa_options = SpatialLengthScaleOptimizationOptions::default();
let options = BlockwiseFitOptions::default();
let wiggle_cfg = effective_linkwiggle.map(|cfg| LinkWiggleConfig {
degree: cfg.degree,
num_internal_knots: cfg.num_internal_knots,
penalty_orders: cfg.penalty_orders,
double_penalty: cfg.double_penalty,
});
if family.is_latent_cloglog() {
return Err(WorkflowError::InvalidConfig {
reason: "latent-cloglog-binomial is not implemented for location-scale fitting"
.to_string(),
}
.into());
}
if family.is_binomial() {
let link_kind = match link_choice.as_ref() {
Some(c) => match StandardLink::try_from(c.link) {
Ok(std_link) => InverseLink::Standard(std_link),
Err(e) => {
return Err(WorkflowError::InvalidConfig {
reason: format!(
"binomial location-scale fitting cannot route link `{}` through `InverseLink::Standard`: {e}",
c.link.name()
),
}
.into());
}
},
None => InverseLink::Standard(StandardLink::Logit),
};
Ok(MaterializedModel {
request: FitRequest::BinomialLocationScale(BinomialLocationScaleFitRequest {
data: data.values.view(),
spec: BinomialLocationScaleTermSpec {
y,
weights,
link_kind,
thresholdspec: meanspec,
log_sigmaspec,
threshold_offset: mean_offset,
log_sigma_offset: noise_offset,
},
wiggle: wiggle_cfg,
options,
kappa_options,
}),
inference_notes,
})
} else if let Some(kind) = dispersion_location_scale_kind(&family.response) {
if wiggle_cfg.is_some() {
return Err(WorkflowError::InvalidConfig {
reason: format!(
"link-wiggle is not supported for {} location-scale models",
kind.family_tag()
),
}
.into());
}
Ok(MaterializedModel {
request: FitRequest::DispersionLocationScale(DispersionLocationScaleFitRequest {
data: data.values.view(),
spec: DispersionGlmLocationScaleTermSpec {
kind,
y,
weights,
meanspec,
log_dispspec: log_sigmaspec,
mean_offset,
log_disp_offset: noise_offset,
},
options,
kappa_options,
}),
inference_notes,
})
} else {
Ok(MaterializedModel {
request: FitRequest::GaussianLocationScale(GaussianLocationScaleFitRequest {
data: data.values.view(),
spec: GaussianLocationScaleTermSpec {
y,
weights,
meanspec,
log_sigmaspec,
mean_offset,
log_sigma_offset: noise_offset,
},
wiggle: wiggle_cfg,
options,
kappa_options,
}),
inference_notes,
})
}
}
/// Maps a response family to the dispersion-family kind used by location-scale fitting.
///
/// Returns `Some` for the negative-binomial, gamma, beta and Tweedie families (the
/// Tweedie power `p` is copied across) and `None` for every other family.
fn dispersion_location_scale_kind(response: &ResponseFamily) -> Option<DispersionFamilyKind> {
    let kind = match response {
        ResponseFamily::Tweedie { p } => DispersionFamilyKind::Tweedie { p: *p },
        ResponseFamily::Beta { .. } => DispersionFamilyKind::Beta,
        ResponseFamily::Gamma => DispersionFamilyKind::Gamma,
        ResponseFamily::NegativeBinomial { .. } => DispersionFamilyKind::NegativeBinomial,
        _ => return None,
    };
    Some(kind)
}
#[cfg(test)]
mod tests;