/// Normalises a "cause count" value into the workflow's `Result<usize, String>` shape.
///
/// Implemented both for a bare `usize` (which can never fail) and for a
/// `Result<usize, E>` whose error is stringified, so callers can treat infallible
/// and fallible counters uniformly.
trait WorkflowCauseCountResult {
    /// Converts `self` into `Ok(count)` or `Err(message)`.
    fn into_workflow_result(self) -> Result<usize, String>;
}

impl WorkflowCauseCountResult for usize {
    fn into_workflow_result(self) -> Result<usize, String> {
        // A plain count is always a success.
        Result::Ok(self)
    }
}

impl<E> WorkflowCauseCountResult for Result<usize, E>
where
    E: ToString,
{
    fn into_workflow_result(self) -> Result<usize, String> {
        match self {
            Ok(count) => Ok(count),
            Err(cause) => Err(cause.to_string()),
        }
    }
}
/// Error type for the fitting-workflow layer.
///
/// `WorkflowError` converts into `String` through its `Display` impl, which lets
/// the `Result<_, String>` functions in this module build it and propagate it with
/// `.into()` or `?`.
#[derive(Debug, Clone)]
pub enum WorkflowError {
    /// Invalid configuration. Bare `String` and `&str` errors also convert into this
    /// variant.
    InvalidConfig { reason: String },
    /// Schema/shape disagreement between data and specification. Term-builder
    /// missing-column/malformed-formula errors and data-layer schema mismatches map here.
    SchemaMismatch { reason: String },
    /// A required companion setting is absent or an unsupported combination was
    /// requested (e.g. a GaussianShift frailty without a fixed sigma).
    MissingDependency { reason: String },
    /// A numerical fitting stage failed, e.g. an outer optimisation did not converge.
    IntegrationFailed { reason: String },
    /// Formula-DSL failure plus a static description of where it happened; the inner
    /// error is exposed through `std::error::Error::source`.
    FormulaDsl {
        context: &'static str,
        source: crate::inference::formula_dsl::FormulaDslError,
    },
    /// A referenced column is absent from the data.
    ColumnNotFound {
        /// Requested column name.
        name: String,
        /// Optional role label; rendered as "<role> column '<name>'".
        role: Option<String>,
        /// Every column name present in the data (printed as the full list).
        available: Vec<String>,
        /// Close matches offered as "Did you mean ..."; may be empty.
        similar: Vec<String>,
        /// When set, the message appends a hint about tab-separated input.
        tsv_hint: bool,
    },
}
impl std::fmt::Display for WorkflowError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
WorkflowError::InvalidConfig { reason }
| WorkflowError::SchemaMismatch { reason }
| WorkflowError::MissingDependency { reason }
| WorkflowError::IntegrationFailed { reason } => f.write_str(reason),
WorkflowError::FormulaDsl { context, source } => write!(f, "{context}: {source}"),
WorkflowError::ColumnNotFound {
name,
role,
available,
similar,
tsv_hint,
} => {
let label = match role {
Some(r) => format!("{r} column '{name}'"),
None => format!("column '{name}'"),
};
let tsv_suffix = if *tsv_hint {
" — your file appears to be tab-separated; gam expects comma-separated CSV. \
Replace tabs with commas, or pre-convert with `tr '\\t' ',' < file.tsv > file.csv`."
} else {
""
};
if similar.is_empty() {
write!(
f,
"{label} not found in data. Available columns: [{}]{tsv_suffix}",
available.join(", ")
)
} else {
write!(
f,
"{label} not found in data. Did you mean one of [{}]? Full list: [{}]{tsv_suffix}",
similar.join(", "),
available.join(", ")
)
}
}
}
}
}
impl std::error::Error for WorkflowError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
WorkflowError::FormulaDsl { source, .. } => Some(source),
WorkflowError::InvalidConfig { .. }
| WorkflowError::SchemaMismatch { .. }
| WorkflowError::MissingDependency { .. }
| WorkflowError::IntegrationFailed { .. }
| WorkflowError::ColumnNotFound { .. } => None,
}
}
}
impl From<WorkflowError> for String {
fn from(err: WorkflowError) -> String {
err.to_string()
}
}
impl From<String> for WorkflowError {
fn from(reason: String) -> Self {
Self::InvalidConfig { reason }
}
}
impl From<&str> for WorkflowError {
fn from(reason: &str) -> Self {
Self::InvalidConfig {
reason: reason.to_string(),
}
}
}
impl From<crate::inference::formula_dsl::FormulaDslError> for WorkflowError {
fn from(err: crate::inference::formula_dsl::FormulaDslError) -> Self {
Self::FormulaDsl {
context: "workflow formula materialization",
source: err,
}
}
}
impl From<crate::terms::term_builder::TermBuilderError> for WorkflowError {
fn from(err: crate::terms::term_builder::TermBuilderError) -> Self {
use crate::terms::term_builder::TermBuilderError;
match err {
TermBuilderError::ColumnNotFound {
name,
role,
available,
similar,
tsv_hint,
} => Self::ColumnNotFound {
name,
role,
available,
similar,
tsv_hint,
},
TermBuilderError::MissingColumn { reason }
| TermBuilderError::MalformedFormula { reason } => Self::SchemaMismatch { reason },
TermBuilderError::IncompatibleConfig { reason }
| TermBuilderError::InvalidOption { reason }
| TermBuilderError::UnsupportedFeature { reason }
| TermBuilderError::DegenerateData { reason } => Self::InvalidConfig { reason },
}
}
}
impl From<crate::inference::data::DataError> for WorkflowError {
fn from(err: crate::inference::data::DataError) -> Self {
use crate::inference::data::DataError;
match err {
DataError::ColumnNotFound {
name,
role,
available,
similar,
tsv_hint,
} => Self::ColumnNotFound {
name,
role,
available,
similar,
tsv_hint,
},
DataError::SchemaMismatch { reason } => Self::SchemaMismatch { reason },
DataError::ParseError { reason }
| DataError::EncodingFailure { reason }
| DataError::EmptyInput { reason }
| DataError::InvalidValue { reason } => Self::InvalidConfig { reason },
}
}
}
/// Basis settings for a link-wiggle block (spline degree, internal knot count and
/// penalty configuration).
#[derive(Clone, Debug)]
pub struct LinkWiggleConfig {
    /// Spline degree of the wiggle basis.
    pub degree: usize,
    /// Number of interior knots.
    pub num_internal_knots: usize,
    /// Penalty orders applied to the wiggle coefficients
    /// (NOTE(review): presumably difference/derivative orders — confirm).
    pub penalty_orders: Vec<usize>,
    /// Whether to add a second penalty (NOTE(review): presumably a null-space
    /// penalty — confirm).
    pub double_penalty: bool,
}
/// Link-wiggle refit configuration for standard binomial fits.
#[derive(Clone)]
pub struct StandardBinomialWiggleConfig {
    /// Inverse link used when the pilot fit records a `Standard` link state with
    /// no concrete link.
    pub link_kind: InverseLink,
    /// Wiggle basis settings.
    pub wiggle: LinkWiggleConfig,
    /// Options for the joint wiggle refit.
    pub refit_options: BlockwiseFitOptions,
}
/// Request for a standard single-predictor fit (handled by `fit_standard_model`).
///
/// Unlike the other request types, this one owns its data. The `'a` lifetime exists
/// only through `_marker`, which lets every payload share the `FitRequest<'a>`
/// parameter.
pub struct StandardFitRequest<'a> {
    /// Data matrix passed to the fitters together with `spec`.
    pub data: Array2<f64>,
    /// Response; its length is the observation count used in cache keys.
    pub y: Array1<f64>,
    /// Per-observation weights.
    pub weights: Array1<f64>,
    /// Per-observation offset.
    pub offset: Array1<f64>,
    /// Term-collection specification for the model.
    pub spec: TermCollectionSpec,
    /// Likelihood specification; its `Debug` form is part of the cache fingerprint.
    pub family: LikelihoodSpec,
    /// Options for the pilot fit.
    pub options: FitOptions,
    /// Spatial length-scale optimisation settings.
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
    /// Optional binomial mean link-wiggle refit run after the pilot fit.
    pub wiggle: Option<StandardBinomialWiggleConfig>,
    /// Coefficient groups; must be empty when `latent_coord` is set.
    pub coefficient_groups: Vec<CoefficientGroupSpec>,
    /// Per-penalty-block gamma priors; must be empty when `latent_coord` is set.
    /// NOTE(review): tuple layout presumably `(block name, shape, rate)` — confirm.
    pub penalty_block_gamma_priors: Vec<(String, f64, f64)>,
    /// Optional latent-coordinate optimisation configuration.
    pub latent_coord: Option<StandardLatentCoordConfig>,
    #[doc(hidden)]
    pub _marker: std::marker::PhantomData<&'a ()>,
}
/// Gaussian location-scale request (mean and log-sigma predictors).
pub struct GaussianLocationScaleFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub spec: GaussianLocationScaleTermSpec,
    /// Optional link-wiggle block.
    pub wiggle: Option<LinkWiggleConfig>,
    /// Blockwise fit options; also carries the cache session and mirrors.
    pub options: BlockwiseFitOptions,
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
}
/// Binomial location-scale request (threshold and log-sigma predictors).
pub struct BinomialLocationScaleFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub spec: BinomialLocationScaleTermSpec,
    /// Optional link-wiggle block.
    pub wiggle: Option<LinkWiggleConfig>,
    /// Blockwise fit options; also carries the cache session and mirrors.
    pub options: BlockwiseFitOptions,
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
}
/// Dispersion-GLM location-scale request (mean and log-dispersion predictors).
pub struct DispersionLocationScaleFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub spec: DispersionGlmLocationScaleTermSpec,
    /// Blockwise fit options; also carries the cache session and mirrors.
    pub options: BlockwiseFitOptions,
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
}
/// Survival location-scale request.
pub struct SurvivalLocationScaleFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub spec: SurvivalLocationScaleTermSpec,
    /// Optional link-wiggle block.
    pub wiggle: Option<LinkWiggleConfig>,
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
    /// Whether the inverse-link parameters are optimised as well
    /// (NOTE(review): inferred from the name — confirm at the fitter).
    pub optimize_inverse_link: bool,
    /// Request-level cache session. `attach_cache_session` fills this slot and
    /// `spec.cache_session`.
    pub cache_session: Option<std::sync::Arc<crate::cache::Session>>,
}
/// Survival transformation-model request.
pub struct SurvivalTransformationFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub spec: SurvivalTransformationTermSpec,
    /// Cache session; this request type has no mirror list.
    pub cache_session: Option<std::sync::Arc<crate::cache::Session>>,
}
/// Fully materialised inputs for a survival transformation fit.
///
/// All per-observation arrays are expected to have the same length; `age_entry.len()`
/// is used as the observation count for caching.
#[derive(Clone)]
pub struct SurvivalTransformationTermSpec {
    /// Entry ages/times per observation.
    pub age_entry: Array1<f64>,
    /// Exit ages/times per observation.
    pub age_exit: Array1<f64>,
    /// Event indicator per observation (NOTE(review): presumably 0/1 — confirm).
    pub event_target: Array1<u8>,
    pub weights: Array1<f64>,
    /// Covariate term collection; its structure is part of the cache fingerprint.
    pub covariate_spec: TermCollectionSpec,
    pub covariate_offset: Array1<f64>,
    pub baseline_cfg: crate::families::survival_construction::SurvivalBaselineConfig,
    /// Likelihood mode; its `Debug` form is part of the cache fingerprint.
    pub likelihood_mode: crate::families::survival_construction::SurvivalLikelihoodMode,
    /// NOTE(review): reference time used to anchor the time basis — inferred from the
    /// name, confirm.
    pub time_anchor: f64,
    /// Prebuilt time basis; `time_build.basisname` is part of the cache fingerprint.
    pub time_build: crate::families::survival_construction::SurvivalTimeBuildOutput,
    /// Optional time-wiggle block.
    pub timewiggle: Option<LinkWiggleFormulaSpec>,
    /// NOTE(review): presumably Weibull starting values — confirm parameter order.
    pub weibull_seed: Option<(f64, f64)>,
    /// NOTE(review): ridge strength; where it is applied is not visible here.
    pub ridge_lambda: f64,
    pub penalty_block_gamma_priors: Vec<(String, f64, f64)>,
}
/// Reports whether `link` carries parameters that an outer optimiser could tune.
///
/// SAS and beta-logistic links are always parametric. A mixture link is parametric
/// only when it has a non-empty `rho` vector. Standard and latent-cloglog links have
/// nothing to optimise.
pub(crate) fn survival_inverse_link_has_free_parameters(link: &InverseLink) -> bool {
    let tunable = match link {
        InverseLink::Standard(_) | InverseLink::LatentCLogLog(_) => false,
        InverseLink::Mixture(mixture) => !mixture.rho.is_empty(),
        InverseLink::BetaLogistic(_) | InverseLink::Sas(_) => true,
    };
    tunable
}
/// Turns a finished outer optimisation over inverse-link parameters into a concrete
/// [`InverseLink`].
///
/// * `result` — outcome of the outer optimisation; `result.rho` is handed to `recover`.
/// * `context` — human-readable stage name used as the message prefix.
/// * `recover` — maps the final `rho` back to a link, or `None` if it is invalid.
///
/// # Errors
/// Both failure paths are reported as [`WorkflowError::IntegrationFailed`] flattened to
/// `String`. This keeps them consistent; previously the second path returned a bare
/// string, with the same text:
/// * the optimiser did not converge (the message includes iterations, final objective
///   and gradient-norm report);
/// * `recover` rejected the final `rho`.
fn recover_converged_survival_inverse_link<R>(
    result: crate::solver::outer_strategy::OuterResult,
    context: &str,
    recover: R,
) -> Result<InverseLink, String>
where
    R: FnOnce(&Array1<f64>) -> Option<InverseLink>,
{
    if !result.converged {
        return Err(WorkflowError::IntegrationFailed {
            reason: format!(
                "{context} did not converge after {} iterations (final_objective={:.6e}, final_grad_norm={})",
                result.iterations,
                result.final_value,
                result.final_grad_norm_report(),
            ),
        }
        .into());
    }
    recover(&result.rho).ok_or_else(|| {
        String::from(WorkflowError::IntegrationFailed {
            reason: format!(
                "{context} produced an invalid inverse-link state at rho={:?}",
                result.rho.to_vec()
            ),
        })
    })
}
/// Bernoulli marginal-slope request (marginal and log-slope predictors).
pub struct BernoulliMarginalSlopeFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub spec: BernoulliMarginalSlopeTermSpec,
    /// Blockwise fit options; also carries the cache session and mirrors.
    pub options: BlockwiseFitOptions,
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
    /// Resource policy for this fit (NOTE(review): how it is consumed is not visible here).
    pub policy: crate::resource::ResourcePolicy,
}
/// Survival marginal-slope request (marginal and log-slope predictors, optional frailty).
pub struct SurvivalMarginalSlopeFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub spec: SurvivalMarginalSlopeTermSpec,
    /// Blockwise fit options; also carries the cache session and mirrors.
    pub options: BlockwiseFitOptions,
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
}
// NOTE(review): none of these constants is referenced in the visible part of this file;
// the descriptions below are inferred from their names — confirm at the use sites.
/// Floor applied before taking the log of a smoothing parameter, presumably to keep
/// `ln` finite for values that underflow towards zero.
const LOG_LAMBDA_UNDERFLOW_FLOOR: f64 = 1e-300;
/// Iteration cap for the survival-transformation PIRLS inner solve.
const SURVIVAL_TRANSFORMATION_PIRLS_MAX_ITERATIONS: usize = 400;
/// Convergence tolerance for the survival-transformation PIRLS inner solve.
const SURVIVAL_TRANSFORMATION_PIRLS_CONVERGENCE_TOL: f64 = 1e-6;
/// Maximum number of step halvings per PIRLS iteration.
const SURVIVAL_TRANSFORMATION_PIRLS_MAX_STEP_HALVING: usize = 40;
/// Smallest step size tried before a PIRLS line search gives up.
const SURVIVAL_TRANSFORMATION_PIRLS_MIN_STEP_SIZE: f64 = 1e-12;
/// Latent survival request with an explicit frailty specification.
pub struct LatentSurvivalFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub spec: LatentSurvivalTermSpec,
    /// Frailty specification; its `Debug` form is part of the cache fingerprint.
    pub frailty: FrailtySpec,
    /// Blockwise fit options; also carries the cache session and mirrors.
    pub options: BlockwiseFitOptions,
}
/// Latent binary request with an explicit frailty specification.
pub struct LatentBinaryFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub spec: LatentBinaryTermSpec,
    /// Frailty specification; its `Debug` form is part of the cache fingerprint.
    pub frailty: FrailtySpec,
    /// Blockwise fit options; also carries the cache session and mirrors.
    pub options: BlockwiseFitOptions,
}
/// Transformation-normal request; `response.len()` is the observation count.
pub struct TransformationNormalFitRequest<'a> {
    pub data: ArrayView2<'a, f64>,
    pub response: Array1<f64>,
    pub weights: Array1<f64>,
    pub offset: Array1<f64>,
    pub covariate_spec: TermCollectionSpec,
    pub config: TransformationNormalConfig,
    /// Blockwise fit options; also carries the cache session and mirrors.
    pub options: BlockwiseFitOptions,
    pub kappa_options: SpatialLengthScaleOptimizationOptions,
    /// Optional warm start for the transformation fit.
    pub warm_start: Option<TransformationWarmStart>,
}
/// A fit request for exactly one model family.
///
/// Every payload implements [`FamilyFitRequest`]; generic operations (tagging, cache
/// keys, cache-session plumbing) are dispatched with `family_dispatch!`.
pub enum FitRequest<'a> {
    Standard(StandardFitRequest<'a>),
    GaussianLocationScale(GaussianLocationScaleFitRequest<'a>),
    BinomialLocationScale(BinomialLocationScaleFitRequest<'a>),
    DispersionLocationScale(DispersionLocationScaleFitRequest<'a>),
    SurvivalLocationScale(SurvivalLocationScaleFitRequest<'a>),
    SurvivalTransformation(SurvivalTransformationFitRequest<'a>),
    BernoulliMarginalSlope(BernoulliMarginalSlopeFitRequest<'a>),
    SurvivalMarginalSlope(SurvivalMarginalSlopeFitRequest<'a>),
    LatentSurvival(LatentSurvivalFitRequest<'a>),
    LatentBinary(LatentBinaryFitRequest<'a>),
    TransformationNormal(TransformationNormalFitRequest<'a>),
}
/// Per-family behaviour shared by every [`FitRequest`] payload: identification, cache
/// fingerprinting and cache-session plumbing.
pub trait FamilyFitRequest {
    /// Stable family identifier, embedded in the `family=` segment of cache keys.
    const TAG: &'static str;
    /// Returns [`Self::TAG`].
    fn tag(&self) -> &'static str {
        Self::TAG
    }
    /// Number of observations; reported in the `dims=` segment of `FitRequest::cache_key`.
    fn n_obs(&self) -> usize;
    /// Number of data columns; reported in the `dims=` segment of `FitRequest::cache_key`.
    fn n_cols(&self) -> usize;
    /// Writes the structural fingerprint used by `FitRequest::cache_key`.
    fn write_shape_hash(&self, h: &mut crate::cache::Fingerprinter);
    /// Writes the fingerprint used by `FitRequest::cache_seed_key`. The implementations
    /// in this file leave out the observation count, so the seed key does not depend
    /// on the number of rows.
    fn write_seed_hash(&self, h: &mut crate::cache::Fingerprinter);
    /// Offers a cache session. The implementations in this file keep an already
    /// attached session; some families ignore the session.
    fn attach_cache_session(&mut self, session: std::sync::Arc<crate::cache::Session>);
    /// Registers an additional mirror session; some families ignore it.
    fn attach_cache_mirror(&mut self, mirror: std::sync::Arc<crate::cache::Session>);
}
// Expands to an exhaustive `match` over every `FitRequest` variant, binding the
// payload to `$req` and evaluating `$body` in each arm. Because `$body` is type-checked
// separately per arm, it may call `FamilyFitRequest` methods on each concrete payload
// type without boxing. Adding a `FitRequest` variant without extending this list is a
// compile error (the match is exhaustive).
macro_rules! family_dispatch {
    ($scrutinee:expr, $req:ident => $body:expr) => {
        match $scrutinee {
            FitRequest::Standard($req) => $body,
            FitRequest::GaussianLocationScale($req) => $body,
            FitRequest::BinomialLocationScale($req) => $body,
            FitRequest::DispersionLocationScale($req) => $body,
            FitRequest::SurvivalLocationScale($req) => $body,
            FitRequest::SurvivalTransformation($req) => $body,
            FitRequest::BernoulliMarginalSlope($req) => $body,
            FitRequest::SurvivalMarginalSlope($req) => $body,
            FitRequest::LatentSurvival($req) => $body,
            FitRequest::LatentBinary($req) => $body,
            FitRequest::TransformationNormal($req) => $body,
        }
    };
}
impl<'a> FamilyFitRequest for StandardFitRequest<'a> {
    const TAG: &'static str = "standard";

    fn n_cols(&self) -> usize {
        self.data.ncols()
    }

    /// Observations are counted from the response vector.
    fn n_obs(&self) -> usize {
        self.y.len()
    }

    /// Fingerprints the likelihood (via its `Debug` form), row and column counts, and
    /// the term-collection structure.
    fn write_shape_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let family_repr = format!("{:?}", self.family);
        h.write_str("standard");
        h.write_str(&family_repr);
        h.write_usize(self.y.len());
        h.write_usize(self.data.ncols());
        self.spec.write_structural_shape_hash(h);
    }

    /// Same ingredients as the shape fingerprint, minus the observation count.
    fn write_seed_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let family_repr = format!("{:?}", self.family);
        h.write_str("standard-seed");
        h.write_str(&family_repr);
        h.write_usize(self.data.ncols());
        self.spec.write_structural_shape_hash(h);
    }

    /// This request type stores no cache session, so the handle is released.
    fn attach_cache_session(&mut self, session: std::sync::Arc<crate::cache::Session>) {
        std::mem::drop(session);
    }

    /// This request type stores no mirrors, so the handle is released.
    fn attach_cache_mirror(&mut self, mirror: std::sync::Arc<crate::cache::Session>) {
        std::mem::drop(mirror);
    }
}
impl<'a> FamilyFitRequest for GaussianLocationScaleFitRequest<'a> {
    const TAG: &'static str = "gaussian-location-scale";

    fn n_cols(&self) -> usize {
        self.data.ncols()
    }

    fn n_obs(&self) -> usize {
        self.spec.y.len()
    }

    /// Fingerprints the row and column counts plus the mean and log-sigma structure.
    fn write_shape_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let spec = &self.spec;
        h.write_str("gauss-ls");
        h.write_usize(spec.y.len());
        h.write_usize(self.data.ncols());
        spec.meanspec.write_structural_shape_hash(h);
        spec.log_sigmaspec.write_structural_shape_hash(h);
    }

    /// Row-count-free variant of the shape fingerprint.
    fn write_seed_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let spec = &self.spec;
        h.write_str("gauss-ls-seed");
        h.write_usize(self.data.ncols());
        spec.meanspec.write_structural_shape_hash(h);
        spec.log_sigmaspec.write_structural_shape_hash(h);
    }

    /// Installs `session` only if no session is attached yet.
    fn attach_cache_session(&mut self, session: std::sync::Arc<crate::cache::Session>) {
        if self.options.cache_session.is_none() {
            self.options.cache_session = Some(session);
        }
    }

    fn attach_cache_mirror(&mut self, mirror: std::sync::Arc<crate::cache::Session>) {
        self.options.cache_mirror_sessions.push(mirror);
    }
}
impl<'a> FamilyFitRequest for BinomialLocationScaleFitRequest<'a> {
    const TAG: &'static str = "binomial-location-scale";

    fn n_cols(&self) -> usize {
        self.data.ncols()
    }

    fn n_obs(&self) -> usize {
        self.spec.y.len()
    }

    /// Fingerprints the row and column counts, the link (via `Debug`), and the
    /// threshold and log-sigma structure.
    fn write_shape_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let spec = &self.spec;
        let link_repr = format!("{:?}", spec.link_kind);
        h.write_str("binom-ls");
        h.write_usize(spec.y.len());
        h.write_usize(self.data.ncols());
        h.write_str(&link_repr);
        spec.thresholdspec.write_structural_shape_hash(h);
        spec.log_sigmaspec.write_structural_shape_hash(h);
    }

    /// Row-count-free variant of the shape fingerprint.
    fn write_seed_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let spec = &self.spec;
        let link_repr = format!("{:?}", spec.link_kind);
        h.write_str("binom-ls-seed");
        h.write_usize(self.data.ncols());
        h.write_str(&link_repr);
        spec.thresholdspec.write_structural_shape_hash(h);
        spec.log_sigmaspec.write_structural_shape_hash(h);
    }

    /// Installs `session` only if no session is attached yet.
    fn attach_cache_session(&mut self, session: std::sync::Arc<crate::cache::Session>) {
        if self.options.cache_session.is_none() {
            self.options.cache_session = Some(session);
        }
    }

    fn attach_cache_mirror(&mut self, mirror: std::sync::Arc<crate::cache::Session>) {
        self.options.cache_mirror_sessions.push(mirror);
    }
}
impl<'a> FamilyFitRequest for DispersionLocationScaleFitRequest<'a> {
    const TAG: &'static str = "dispersion-location-scale";

    fn n_cols(&self) -> usize {
        self.data.ncols()
    }

    fn n_obs(&self) -> usize {
        self.spec.y.len()
    }

    /// Fingerprints the dispersion family tag, the row and column counts, and the mean
    /// and log-dispersion structure.
    fn write_shape_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let spec = &self.spec;
        h.write_str("disp-ls");
        h.write_str(spec.kind.family_tag());
        h.write_usize(spec.y.len());
        h.write_usize(self.data.ncols());
        spec.meanspec.write_structural_shape_hash(h);
        spec.log_dispspec.write_structural_shape_hash(h);
    }

    /// Row-count-free variant of the shape fingerprint.
    fn write_seed_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let spec = &self.spec;
        h.write_str("disp-ls-seed");
        h.write_str(spec.kind.family_tag());
        h.write_usize(self.data.ncols());
        spec.meanspec.write_structural_shape_hash(h);
        spec.log_dispspec.write_structural_shape_hash(h);
    }

    /// Installs `session` only if no session is attached yet.
    fn attach_cache_session(&mut self, session: std::sync::Arc<crate::cache::Session>) {
        if self.options.cache_session.is_none() {
            self.options.cache_session = Some(session);
        }
    }

    fn attach_cache_mirror(&mut self, mirror: std::sync::Arc<crate::cache::Session>) {
        self.options.cache_mirror_sessions.push(mirror);
    }
}
impl<'a> FamilyFitRequest for SurvivalLocationScaleFitRequest<'a> {
    const TAG: &'static str = "survival-location-scale";

    fn n_cols(&self) -> usize {
        self.data.ncols()
    }

    fn n_obs(&self) -> usize {
        self.spec.age_entry.len()
    }

    /// Fingerprints the row and column counts, the inverse link (via `Debug`), and the
    /// threshold and log-sigma structure.
    fn write_shape_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let spec = &self.spec;
        let link_repr = format!("{:?}", spec.inverse_link);
        h.write_str("surv-ls");
        h.write_usize(spec.age_entry.len());
        h.write_usize(self.data.ncols());
        h.write_str(&link_repr);
        spec.thresholdspec.write_structural_shape_hash(h);
        spec.log_sigmaspec.write_structural_shape_hash(h);
    }

    /// Row-count-free variant of the shape fingerprint.
    fn write_seed_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let spec = &self.spec;
        let link_repr = format!("{:?}", spec.inverse_link);
        h.write_str("surv-ls-seed");
        h.write_usize(self.data.ncols());
        h.write_str(&link_repr);
        spec.thresholdspec.write_structural_shape_hash(h);
        spec.log_sigmaspec.write_structural_shape_hash(h);
    }

    /// Fills both the request-level and the spec-level session slots, leaving any
    /// slot that is already occupied untouched.
    fn attach_cache_session(&mut self, session: std::sync::Arc<crate::cache::Session>) {
        self.cache_session.get_or_insert_with(|| session.clone());
        self.spec.cache_session.get_or_insert(session);
    }

    /// Mirrors live on the spec for this family.
    fn attach_cache_mirror(&mut self, mirror: std::sync::Arc<crate::cache::Session>) {
        self.spec.cache_mirror_sessions.push(mirror);
    }
}
impl<'a> FamilyFitRequest for SurvivalTransformationFitRequest<'a> {
    const TAG: &'static str = "survival-transformation";

    fn n_cols(&self) -> usize {
        self.data.ncols()
    }

    fn n_obs(&self) -> usize {
        self.spec.age_entry.len()
    }

    /// Fingerprints the row and column counts, the likelihood mode (via `Debug`), the
    /// time-basis name and the covariate structure.
    fn write_shape_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let spec = &self.spec;
        let mode_repr = format!("{:?}", spec.likelihood_mode);
        h.write_str("surv-tn");
        h.write_usize(spec.age_entry.len());
        h.write_usize(self.data.ncols());
        h.write_str(&mode_repr);
        h.write_str(&spec.time_build.basisname);
        spec.covariate_spec.write_structural_shape_hash(h);
    }

    /// Row-count-free variant of the shape fingerprint.
    fn write_seed_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let spec = &self.spec;
        let mode_repr = format!("{:?}", spec.likelihood_mode);
        h.write_str("surv-tn-seed");
        h.write_usize(self.data.ncols());
        h.write_str(&mode_repr);
        h.write_str(&spec.time_build.basisname);
        spec.covariate_spec.write_structural_shape_hash(h);
    }

    /// Installs `session` only if no session is attached yet.
    fn attach_cache_session(&mut self, session: std::sync::Arc<crate::cache::Session>) {
        if self.cache_session.is_none() {
            self.cache_session = Some(session);
        }
    }

    /// This request type has no mirror list, so the handle is released.
    fn attach_cache_mirror(&mut self, mirror: std::sync::Arc<crate::cache::Session>) {
        std::mem::drop(mirror);
    }
}
impl<'a> FamilyFitRequest for BernoulliMarginalSlopeFitRequest<'a> {
    const TAG: &'static str = "bernoulli-marginal-slope";

    fn n_cols(&self) -> usize {
        self.data.ncols()
    }

    fn n_obs(&self) -> usize {
        self.spec.y.len()
    }

    /// Fingerprints the row and column counts, the base link (via `Debug`), and the
    /// marginal and log-slope structure.
    fn write_shape_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let spec = &self.spec;
        let link_repr = format!("{:?}", spec.base_link);
        h.write_str("bern-ms");
        h.write_usize(spec.y.len());
        h.write_usize(self.data.ncols());
        h.write_str(&link_repr);
        spec.marginalspec.write_structural_shape_hash(h);
        spec.logslopespec.write_structural_shape_hash(h);
    }

    /// Row-count-free variant of the shape fingerprint.
    fn write_seed_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let spec = &self.spec;
        let link_repr = format!("{:?}", spec.base_link);
        h.write_str("bern-ms-seed");
        h.write_usize(self.data.ncols());
        h.write_str(&link_repr);
        spec.marginalspec.write_structural_shape_hash(h);
        spec.logslopespec.write_structural_shape_hash(h);
    }

    /// Installs `session` only if no session is attached yet.
    fn attach_cache_session(&mut self, session: std::sync::Arc<crate::cache::Session>) {
        if self.options.cache_session.is_none() {
            self.options.cache_session = Some(session);
        }
    }

    fn attach_cache_mirror(&mut self, mirror: std::sync::Arc<crate::cache::Session>) {
        self.options.cache_mirror_sessions.push(mirror);
    }
}
impl<'a> FamilyFitRequest for SurvivalMarginalSlopeFitRequest<'a> {
    const TAG: &'static str = "survival-marginal-slope";

    fn n_cols(&self) -> usize {
        self.data.ncols()
    }

    fn n_obs(&self) -> usize {
        self.spec.age_entry.len()
    }

    /// Fingerprints the row and column counts, base link and frailty (via `Debug`), the
    /// marginal and log-slope structure, and the optional extra log-slope specs. The
    /// extra specs are written as a presence flag, then a count, then each structure.
    fn write_shape_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let spec = &self.spec;
        h.write_str("surv-ms");
        h.write_usize(spec.age_entry.len());
        h.write_usize(self.data.ncols());
        h.write_str(&format!("{:?}", spec.base_link));
        h.write_str(&format!("{:?}", spec.frailty));
        spec.marginalspec.write_structural_shape_hash(h);
        spec.logslopespec.write_structural_shape_hash(h);
        if let Some(extra) = spec.logslopespecs.as_ref() {
            h.write_bool(true);
            h.write_usize(extra.len());
            for extra_spec in extra {
                extra_spec.write_structural_shape_hash(h);
            }
        } else {
            h.write_bool(false);
        }
    }

    /// Row-count-free variant of the shape fingerprint.
    fn write_seed_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let spec = &self.spec;
        h.write_str("surv-ms-seed");
        h.write_usize(self.data.ncols());
        h.write_str(&format!("{:?}", spec.base_link));
        h.write_str(&format!("{:?}", spec.frailty));
        spec.marginalspec.write_structural_shape_hash(h);
        spec.logslopespec.write_structural_shape_hash(h);
        if let Some(extra) = spec.logslopespecs.as_ref() {
            h.write_bool(true);
            h.write_usize(extra.len());
            for extra_spec in extra {
                extra_spec.write_structural_shape_hash(h);
            }
        } else {
            h.write_bool(false);
        }
    }

    /// Installs `session` only if no session is attached yet.
    fn attach_cache_session(&mut self, session: std::sync::Arc<crate::cache::Session>) {
        if self.options.cache_session.is_none() {
            self.options.cache_session = Some(session);
        }
    }

    fn attach_cache_mirror(&mut self, mirror: std::sync::Arc<crate::cache::Session>) {
        self.options.cache_mirror_sessions.push(mirror);
    }
}
impl<'a> FamilyFitRequest for LatentSurvivalFitRequest<'a> {
    const TAG: &'static str = "latent-survival";

    fn n_cols(&self) -> usize {
        self.data.ncols()
    }

    fn n_obs(&self) -> usize {
        self.spec.age_entry.len()
    }

    /// Fingerprints the row and column counts, the request frailty (via `Debug`) and
    /// the mean structure.
    fn write_shape_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let frailty_repr = format!("{:?}", self.frailty);
        h.write_str("lat-surv");
        h.write_usize(self.spec.age_entry.len());
        h.write_usize(self.data.ncols());
        h.write_str(&frailty_repr);
        self.spec.meanspec.write_structural_shape_hash(h);
    }

    /// Row-count-free variant of the shape fingerprint.
    fn write_seed_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let frailty_repr = format!("{:?}", self.frailty);
        h.write_str("lat-surv-seed");
        h.write_usize(self.data.ncols());
        h.write_str(&frailty_repr);
        self.spec.meanspec.write_structural_shape_hash(h);
    }

    /// Installs `session` only if no session is attached yet.
    fn attach_cache_session(&mut self, session: std::sync::Arc<crate::cache::Session>) {
        if self.options.cache_session.is_none() {
            self.options.cache_session = Some(session);
        }
    }

    fn attach_cache_mirror(&mut self, mirror: std::sync::Arc<crate::cache::Session>) {
        self.options.cache_mirror_sessions.push(mirror);
    }
}
impl<'a> FamilyFitRequest for LatentBinaryFitRequest<'a> {
    const TAG: &'static str = "latent-binary";

    fn n_cols(&self) -> usize {
        self.data.ncols()
    }

    fn n_obs(&self) -> usize {
        self.spec.age_entry.len()
    }

    /// Fingerprints the row and column counts, the request frailty (via `Debug`) and
    /// the mean structure.
    fn write_shape_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let frailty_repr = format!("{:?}", self.frailty);
        h.write_str("lat-bin");
        h.write_usize(self.spec.age_entry.len());
        h.write_usize(self.data.ncols());
        h.write_str(&frailty_repr);
        self.spec.meanspec.write_structural_shape_hash(h);
    }

    /// Row-count-free variant of the shape fingerprint.
    fn write_seed_hash(&self, h: &mut crate::cache::Fingerprinter) {
        let frailty_repr = format!("{:?}", self.frailty);
        h.write_str("lat-bin-seed");
        h.write_usize(self.data.ncols());
        h.write_str(&frailty_repr);
        self.spec.meanspec.write_structural_shape_hash(h);
    }

    /// Installs `session` only if no session is attached yet.
    fn attach_cache_session(&mut self, session: std::sync::Arc<crate::cache::Session>) {
        if self.options.cache_session.is_none() {
            self.options.cache_session = Some(session);
        }
    }

    fn attach_cache_mirror(&mut self, mirror: std::sync::Arc<crate::cache::Session>) {
        self.options.cache_mirror_sessions.push(mirror);
    }
}
impl<'a> FamilyFitRequest for TransformationNormalFitRequest<'a> {
    const TAG: &'static str = "transformation-normal";

    fn n_cols(&self) -> usize {
        self.data.ncols()
    }

    /// Observations are counted from the response vector.
    fn n_obs(&self) -> usize {
        self.response.len()
    }

    /// Fingerprints the row and column counts plus the covariate structure.
    fn write_shape_hash(&self, h: &mut crate::cache::Fingerprinter) {
        h.write_str("tn");
        h.write_usize(self.response.len());
        h.write_usize(self.data.ncols());
        self.covariate_spec.write_structural_shape_hash(h);
    }

    /// Row-count-free variant of the shape fingerprint.
    fn write_seed_hash(&self, h: &mut crate::cache::Fingerprinter) {
        h.write_str("tn-seed");
        h.write_usize(self.data.ncols());
        self.covariate_spec.write_structural_shape_hash(h);
    }

    /// Installs `session` only if no session is attached yet.
    fn attach_cache_session(&mut self, session: std::sync::Arc<crate::cache::Session>) {
        if self.options.cache_session.is_none() {
            self.options.cache_session = Some(session);
        }
    }

    fn attach_cache_mirror(&mut self, mirror: std::sync::Arc<crate::cache::Session>) {
        self.options.cache_mirror_sessions.push(mirror);
    }
}
impl<'a> FitRequest<'a> {
    /// Family identifier of the wrapped request (its `FamilyFitRequest::TAG`).
    pub fn family_tag(&self) -> &'static str {
        family_dispatch!(self, inner => inner.tag())
    }

    /// Full cache key:
    /// `"<schema>/family=<tag>/dims=<n_obs>x<n_cols>/shape=<structural hash>"`.
    pub fn cache_key(&self) -> String {
        let (nrows, ncols) = family_dispatch!(self, inner => (inner.n_obs(), inner.n_cols()));
        let mut fingerprint = crate::cache::Fingerprinter::new();
        family_dispatch!(self, inner => inner.write_shape_hash(&mut fingerprint));
        let shape_hash = fingerprint.finish_hex();
        format!(
            "{}/family={}/dims={}x{}/shape={}",
            crate::solver::persistent_warm_start::cache_schema_tag(),
            self.family_tag(),
            nrows,
            ncols,
            shape_hash,
        )
    }

    /// Seed cache key: `"<schema>/family=<tag>/seed/<seed hash>"`. It has no `dims`
    /// segment, so it is independent of the observation count.
    pub fn cache_seed_key(&self) -> String {
        let mut fingerprint = crate::cache::Fingerprinter::new();
        family_dispatch!(self, inner => inner.write_seed_hash(&mut fingerprint));
        let seed_hash = fingerprint.finish_hex();
        format!(
            "{}/family={}/seed/{}",
            crate::solver::persistent_warm_start::cache_schema_tag(),
            self.family_tag(),
            seed_hash,
        )
    }

    /// Forwards a mirror session to the wrapped family request.
    pub fn attach_cache_mirror(&mut self, mirror: std::sync::Arc<crate::cache::Session>) {
        // Fully qualified call: the inherent method on `FitRequest` shares this name.
        family_dispatch!(self, inner => FamilyFitRequest::attach_cache_mirror(inner, mirror))
    }

    /// Forwards a cache session to the wrapped family request.
    pub fn attach_cache_session(&mut self, session: std::sync::Arc<crate::cache::Session>) {
        family_dispatch!(self, inner => FamilyFitRequest::attach_cache_session(inner, session))
    }
}
/// Outcome of `fit_standard_model`.
pub struct StandardFitResult {
    /// Final fit: the wiggle refit when it succeeded, otherwise the pilot fit.
    pub fit: UnifiedFitResult,
    /// Design matrix of the final fit.
    pub design: TermCollectionDesign,
    /// Term specification of the final fit.
    pub resolvedspec: TermCollectionSpec,
    /// Adaptive-regularisation diagnostics from the pilot fit.
    pub adaptive_diagnostics: Option<AdaptiveRegularizationDiagnostics>,
    /// Fitted link state captured from the pilot fit (kept even after a wiggle refit).
    pub saved_link_state: FittedLinkState,
    /// Wiggle knots; `Some` only when the link-wiggle refit succeeded.
    pub wiggle_knots: Option<Array1<f64>>,
    /// Wiggle degree; `Some` only when the link-wiggle refit succeeded.
    pub wiggle_degree: Option<usize>,
}
/// Outcome of a survival location-scale fit.
pub struct SurvivalLocationScaleFitResult {
    pub fit: SurvivalLocationScaleTermFitResult,
    /// Inverse link associated with `fit`.
    pub inverse_link: InverseLink,
    /// Wiggle knots/degree when a link wiggle was fitted.
    pub wiggle_knots: Option<Array1<f64>>,
    pub wiggle_degree: Option<usize>,
}
/// Outcome of a survival transformation fit, including what is needed to rebuild the
/// time basis later.
pub struct SurvivalTransformationFitResult {
    pub fit: UnifiedFitResult,
    pub resolvedspec: TermCollectionSpec,
    pub baseline_cfg: crate::families::survival_construction::SurvivalBaselineConfig,
    pub likelihood_mode: crate::families::survival_construction::SurvivalLikelihoodMode,
    pub time_basis: crate::families::survival_construction::SavedSurvivalTimeBasis,
    /// NOTE(review): presumably the column count of the time basis before any wiggle
    /// columns — confirm.
    pub time_base_ncols: usize,
    pub baseline_timewiggle: Option<TimeWiggleBlockInput>,
}
/// Internal field-for-field mirror of [`SurvivalLocationScaleFitResult`]; converted with
/// `into_result`.
struct SurvivalLocationScaleProfile {
    fit: SurvivalLocationScaleTermFitResult,
    inverse_link: InverseLink,
    wiggle_knots: Option<Array1<f64>>,
    wiggle_degree: Option<usize>,
}
impl SurvivalLocationScaleProfile {
    /// Moves every field into the public result type unchanged.
    fn into_result(self) -> SurvivalLocationScaleFitResult {
        let Self {
            fit,
            inverse_link,
            wiggle_knots,
            wiggle_degree,
        } = self;
        SurvivalLocationScaleFitResult {
            fit,
            inverse_link,
            wiggle_knots,
            wiggle_degree,
        }
    }
}
/// Result of a fit, one variant per family.
///
/// `SplineScan` and `ResidualCascade` have no counterpart in [`FitRequest`]; they are
/// produced outside the request dispatch visible here.
pub enum FitResult {
    Standard(StandardFitResult),
    GaussianLocationScale(GaussianLocationScaleFitResult),
    BinomialLocationScale(BinomialLocationScaleFitResult),
    DispersionLocationScale(DispersionLocationScaleFitResult),
    SurvivalLocationScale(SurvivalLocationScaleFitResult),
    SurvivalTransformation(SurvivalTransformationFitResult),
    BernoulliMarginalSlope(BernoulliMarginalSlopeFitResult),
    SurvivalMarginalSlope(SurvivalMarginalSlopeFitResult),
    LatentSurvival(LatentSurvivalTermFitResult),
    LatentBinary(LatentBinaryTermFitResult),
    TransformationNormal(TransformationNormalFitResult),
    SplineScan(crate::solver::spline_scan::SplineScanFit),
    ResidualCascade(crate::solver::residual_cascade::ResidualCascadeFit),
}
/// Dispersion location-scale fit together with the dispersion family it used.
pub struct DispersionLocationScaleFitResult {
    pub fit: BlockwiseTermFitResult,
    pub kind: DispersionFamilyKind,
}
/// Resolves the inverse link for a standard-model link-wiggle refit.
///
/// The pilot fit's link state (for `spec`) is converted back into an [`InverseLink`].
/// A `Standard` state with no recorded link falls back to `fallback`. The resolved link
/// must then pass `require_inverse_link_supports_joint_wiggle`.
///
/// # Errors
/// Returns the stringified link-state extraction error, or the error from the
/// joint-wiggle support check.
fn resolved_wiggle_inverse_link(
    spec: &LikelihoodSpec,
    fit: &UnifiedFitResult,
    fallback: &InverseLink,
) -> Result<InverseLink, String> {
    let link_state = fit.fitted_link_state(spec).map_err(|err| err.to_string())?;
    let inverse_link = match link_state {
        FittedLinkState::Standard(maybe_link) => {
            maybe_link.map_or_else(|| fallback.clone(), InverseLink::Standard)
        }
        FittedLinkState::Mixture { state, .. } => InverseLink::Mixture(state),
        FittedLinkState::BetaLogistic { state, .. } => InverseLink::BetaLogistic(state),
        FittedLinkState::Sas { state, .. } => InverseLink::Sas(state),
        FittedLinkState::LatentCLogLog { state } => InverseLink::LatentCLogLog(state),
    };
    require_inverse_link_supports_joint_wiggle(&inverse_link, "standard link wiggle")?;
    Ok(inverse_link)
}
/// Builds a [`DeviationBlockConfig`] from a formula-level link-wiggle spec.
///
/// The scalar `penalty_order` is the highest entry of `penalty_orders`, or 2 when the
/// list is empty. `monotonicity_eps` is taken from the cubic triple-operator defaults.
fn deviation_block_config_from_formula_linkwiggle(
    wiggle: &LinkWiggleFormulaSpec,
) -> DeviationBlockConfig {
    let monotonicity_eps = WigglePenaltyConfig::cubic_triple_operator_default().monotonicity_eps;
    let penalty_order = wiggle.penalty_orders.iter().copied().max().unwrap_or(2);
    DeviationBlockConfig {
        degree: wiggle.degree,
        num_internal_knots: wiggle.num_internal_knots,
        penalty_order,
        penalty_orders: wiggle.penalty_orders.clone(),
        double_penalty: wiggle.double_penalty,
        monotonicity_eps,
    }
}
/// Deviation-block configs for a marginal-slope model, as assigned by
/// `route_marginal_slope_deviation_blocks`.
struct MarginalSlopeDeviationRouting {
    /// Built from the log-slope formula's link wiggle, if any.
    score_warp: Option<DeviationBlockConfig>,
    /// Built from the main formula's link wiggle, if any.
    link_dev: Option<DeviationBlockConfig>,
}
/// Routes formula-level link wiggles to marginal-slope deviation blocks.
///
/// The main formula's wiggle becomes the link deviation, and the log-slope formula's
/// wiggle becomes the score warp. This routing never fails; the `Result` return keeps
/// the call-site signature uniform.
fn route_marginal_slope_deviation_blocks(
    main_linkwiggle: Option<&LinkWiggleFormulaSpec>,
    logslope_linkwiggle: Option<&LinkWiggleFormulaSpec>,
) -> Result<MarginalSlopeDeviationRouting, String> {
    let link_dev = main_linkwiggle.map(deviation_block_config_from_formula_linkwiggle);
    let score_warp = logslope_linkwiggle.map(deviation_block_config_from_formula_linkwiggle);
    Ok(MarginalSlopeDeviationRouting {
        score_warp,
        link_dev,
    })
}
/// Accepts only frailty specs that a fixed-sigma Gaussian-shift consumer can handle.
///
/// Returns a copy of `frailty` when it is `None` or a `GaussianShift` with a fixed
/// sigma.
///
/// # Errors
/// Returns a `MissingDependency` error (as `String`), prefixed by `context`, for a
/// `GaussianShift` without a fixed sigma or for any `HazardMultiplier` frailty.
fn fixed_gaussian_shift_frailty_from_spec(
    frailty: &FrailtySpec,
    context: &str,
) -> Result<FrailtySpec, String> {
    let reason = match frailty {
        FrailtySpec::None => return Ok(FrailtySpec::None),
        FrailtySpec::GaussianShift {
            sigma_fixed: Some(sigma),
        } => {
            return Ok(FrailtySpec::GaussianShift {
                sigma_fixed: Some(*sigma),
            });
        }
        FrailtySpec::GaussianShift { sigma_fixed: None } => {
            format!("{context} currently requires a fixed GaussianShift sigma")
        }
        FrailtySpec::HazardMultiplier { .. } => {
            format!("{context} requires FrailtySpec::GaussianShift or no frailty")
        }
    };
    Err(WorkflowError::MissingDependency { reason }.into())
}
/// Fits a standard model, optionally followed by a binomial mean link-wiggle refit.
///
/// Pilot fit — exactly one of three paths, chosen from the request contents:
/// 1. `latent_coord` is set: latent-coordinate optimisation. Requests that also carry
///    `coefficient_groups` or `penalty_block_gamma_priors` are rejected.
/// 2. `coefficient_groups` or `penalty_block_gamma_priors` is non-empty: grouped /
///    gamma-prior fit; the caller's `spec` is reported unchanged as `resolvedspec`.
/// 3. Otherwise: fit with spatial length-scale optimisation driven by `kappa_options`.
///
/// Wiggle stage (only when `request.wiggle` is `Some`): the inverse link is resolved
/// from the pilot's link state, a wiggle basis is selected from the pilot, and a joint
/// refit is attempted. If the joint solve fails, a warning is logged and the pilot
/// result is returned with no wiggle knots.
///
/// # Errors
/// Returns a `String` if the pilot fit fails, if the latent-coordinate path gets
/// incompatible options, if link resolution fails, or if wiggle basis selection fails.
/// A failed joint wiggle solve is not an error (see above).
fn fit_standard_model(request: StandardFitRequest<'_>) -> Result<StandardFitResult, String> {
    let fitted = if let Some(latent_coord) = request.latent_coord.as_ref() {
        if !request.coefficient_groups.is_empty() || !request.penalty_block_gamma_priors.is_empty()
        {
            return Err("latent-coordinate standard fits do not support coefficient_groups or penalty_block_gamma_priors in the same request".to_string());
        }
        // `request.y` and `request.weights` are borrowed again by the wiggle refit
        // below, so owned copies are passed here.
        fit_term_collectionwith_latent_coord_optimization(
            request.data.view(),
            request.y.clone(),
            request.weights.clone(),
            request.offset.clone(),
            &request.spec,
            latent_coord,
            request.family.clone(),
            &request.options,
        )
        .map_err(|e| e.to_string())?
    } else if !request.coefficient_groups.is_empty()
        || !request.penalty_block_gamma_priors.is_empty()
    {
        // NOTE(review): `kappa_options` is not passed on this path, so no spatial
        // length-scale optimisation happens here — confirm this is intended.
        let fitted = fit_term_collection_with_coefficient_groups_and_penalty_block_gamma_priors(
            request.data.view(),
            request.y.view(),
            request.weights.view(),
            request.offset.view(),
            &request.spec,
            &request.coefficient_groups,
            &request.penalty_block_gamma_priors,
            request.family.clone(),
            &request.options,
        )
        .map_err(|e| e.to_string())?;
        // The caller's spec is reported as the resolved spec on this path.
        crate::terms::smooth::FittedTermCollectionWithSpec {
            fit: fitted.fit,
            design: fitted.design,
            resolvedspec: request.spec.clone(),
            adaptive_diagnostics: fitted.adaptive_diagnostics,
        }
    } else {
        fit_term_collectionwith_spatial_length_scale_optimization(
            request.data.view(),
            request.y.clone(),
            request.weights.clone(),
            request.offset.clone(),
            &request.spec,
            request.family.clone(),
            &request.options,
            &request.kappa_options,
        )
        .map_err(|e| e.to_string())?
    };
    // `saved_link_state` is listed first so the link state is cloned before
    // `fitted.fit` is moved into `fit`.
    let result = StandardFitResult {
        saved_link_state: fitted.fit.fitted_link.clone(),
        fit: fitted.fit,
        design: fitted.design,
        resolvedspec: fitted.resolvedspec,
        adaptive_diagnostics: fitted.adaptive_diagnostics,
        wiggle_knots: None,
        wiggle_degree: None,
    };
    let Some(wiggle) = request.wiggle else {
        return Ok(result);
    };
    let wiggle_options = wiggle.refit_options.clone();
    // Errors here propagate: an unsupported link is a configuration problem, not a
    // numerical failure.
    let wiggle_link_kind =
        resolved_wiggle_inverse_link(&request.family, &result.fit, &wiggle.link_kind)?;
    let selected_wiggle_basis = select_binomial_mean_link_wiggle_basis_from_pilot(
        &result.design,
        &result.fit,
        &WiggleBlockConfig {
            degree: wiggle.wiggle.degree,
            num_internal_knots: wiggle.wiggle.num_internal_knots,
            // NOTE(review): fixed at 2 while the full `penalty_orders` list is passed
            // separately below; `deviation_block_config_from_formula_linkwiggle` uses
            // the maximum of the list instead — confirm which one the selector honours.
            penalty_order: 2,
            double_penalty: wiggle.wiggle.double_penalty,
        },
        &wiggle.wiggle.penalty_orders,
    )?;
    let solved = match fit_binomial_mean_wiggle_terms_with_selected_basis(
        request.data.view(),
        &result.resolvedspec,
        &result.design,
        &result.fit,
        &request.y,
        &request.weights,
        wiggle_link_kind,
        selected_wiggle_basis,
        &wiggle_options,
        &request.kappa_options,
    ) {
        Ok(solved) => solved,
        Err(e) => {
            // Deliberate best-effort: a failed joint solve degrades to the pilot fit
            // instead of failing the whole request.
            log::warn!(
                "[linkwiggle] binomial mean link-wiggle joint solve did not converge ({e}); \
                 falling back to the no-wiggle baseline fit (the large-smoothing limit of the \
                 penalized wiggle model, which contains it as a limiting case)"
            );
            return Ok(result);
        }
    };
    // Fit, design and spec come from the joint refit; the pilot's link state and
    // adaptive diagnostics are carried over.
    Ok(StandardFitResult {
        saved_link_state: result.saved_link_state,
        fit: solved.fit,
        design: solved.design,
        resolvedspec: solved.resolvedspec,
        adaptive_diagnostics: result.adaptive_diagnostics,
        wiggle_knots: Some(solved.wiggle_knots),
        wiggle_degree: Some(solved.wiggle_degree),
    })
}
/// Request fields shared by the location-scale families, as produced by
/// [`LocationScaleWorkflowAdapter::into_parts`].
struct LocationScaleWorkflowParts<'a, S> {
    /// Borrowed data matrix.
    data: ArrayView2<'a, f64>,
    /// Family-specific term specification.
    spec: S,
    /// Link-wiggle settings; `None` selects the plain (no-wiggle) path.
    wiggle: Option<LinkWiggleConfig>,
    options: BlockwiseFitOptions,
    kappa_options: SpatialLengthScaleOptimizationOptions,
}
/// Family hooks used by the generic location-scale driver
/// `fit_location_scale_with_optional_wiggle`.
///
/// Without a wiggle, the driver calls `fit_plain` then `assemble_plain`. With a wiggle,
/// it calls `fit_pilot`, then `refit_with_selected_wiggle`, then `assemble_with_wiggle`.
trait LocationScaleWorkflowAdapter {
    /// Family-specific term specification.
    type Spec;
    /// Family-specific request type.
    type Request<'a>;
    /// Family-specific result type.
    type Result;
    /// Splits a request into the shared parts.
    fn into_parts<'a>(request: Self::Request<'a>) -> LocationScaleWorkflowParts<'a, Self::Spec>;
    /// Fits the no-wiggle pilot whose output seeds the wiggle refit.
    fn fit_pilot(
        data: ArrayView2<'_, f64>,
        spec: &Self::Spec,
        options: &BlockwiseFitOptions,
        kappa_options: &SpatialLengthScaleOptimizationOptions,
    ) -> Result<BlockwiseTermFitResult, String>;
    /// Selects a wiggle basis from `pilot` and refits jointly with the wiggle block.
    fn refit_with_selected_wiggle(
        data: ArrayView2<'_, f64>,
        spec: Self::Spec,
        pilot: &BlockwiseTermFitResult,
        wiggle_cfg: &LinkWiggleConfig,
        options: &BlockwiseFitOptions,
        kappa_options: &SpatialLengthScaleOptimizationOptions,
    ) -> Result<BlockwiseTermWiggleFitResult, String>;
    /// Fits the model without a wiggle block.
    fn fit_plain(
        data: ArrayView2<'_, f64>,
        spec: Self::Spec,
        options: &BlockwiseFitOptions,
        kappa_options: &SpatialLengthScaleOptimizationOptions,
    ) -> Result<BlockwiseTermFitResult, String>;
    /// Wraps a no-wiggle fit in the family result type.
    fn assemble_plain(fit: BlockwiseTermFitResult) -> Self::Result;
    /// Wraps a wiggle fit in the family result type. The driver passes as
    /// `beta_link_wiggle` the coefficients of the fit's third block (index 2), if present.
    fn assemble_with_wiggle(
        fit: BlockwiseTermFitResult,
        wiggle_knots: Array1<f64>,
        wiggle_degree: usize,
        beta_link_wiggle: Option<Vec<f64>>,
    ) -> Self::Result;
}
/// Shared driver for location-scale fits, parameterised by a family adapter `A`.
///
/// With no wiggle configured, the model is fitted once and wrapped with
/// `assemble_plain`. Otherwise:
/// 1. a pilot fit chooses the wiggle basis;
/// 2. the model is refitted jointly with that basis;
/// 3. the joint fit is reassembled into a `BlockwiseTermFitResult`.
///
/// The wiggle coefficients are read from block state index 2 when present.
fn fit_location_scale_with_optional_wiggle<A: LocationScaleWorkflowAdapter>(
    request: A::Request<'_>,
) -> Result<A::Result, String> {
    let LocationScaleWorkflowParts {
        data,
        spec,
        wiggle,
        options,
        kappa_options,
    } = A::into_parts(request);
    match wiggle {
        None => A::fit_plain(data, spec, &options, &kappa_options).map(A::assemble_plain),
        Some(wiggle_cfg) => {
            let pilot = A::fit_pilot(data, &spec, &options, &kappa_options)?;
            let refit = A::refit_with_selected_wiggle(
                data,
                spec,
                &pilot,
                &wiggle_cfg,
                &options,
                &kappa_options,
            )?;
            let knots = refit.wiggle_knots;
            let degree = refit.wiggle_degree;
            let term_fit = refit.fit;
            let joint = term_fit.fit;
            // Location-scale-with-wiggle layouts place the wiggle block third.
            let wiggle_beta = joint.block_states.get(2).map(|state| state.beta.to_vec());
            let rebuilt = BlockwiseTermFitResult::try_from_parts(BlockwiseTermFitResultParts {
                fit: joint,
                meanspec_resolved: term_fit.meanspec_resolved,
                noisespec_resolved: term_fit.noisespec_resolved,
                mean_design: term_fit.mean_design,
                noise_design: term_fit.noise_design,
            })?;
            Ok(A::assemble_with_wiggle(rebuilt, knots, degree, wiggle_beta))
        }
    }
}
struct GaussianLocationScaleWorkflow;
impl LocationScaleWorkflowAdapter for GaussianLocationScaleWorkflow {
type Spec = GaussianLocationScaleTermSpec;
type Request<'a> = GaussianLocationScaleFitRequest<'a>;
type Result = GaussianLocationScaleFitResult;
fn into_parts<'a>(request: Self::Request<'a>) -> LocationScaleWorkflowParts<'a, Self::Spec> {
LocationScaleWorkflowParts {
data: request.data,
spec: request.spec,
wiggle: request.wiggle,
options: request.options,
kappa_options: request.kappa_options,
}
}
fn fit_pilot(
data: ArrayView2<'_, f64>,
spec: &Self::Spec,
options: &BlockwiseFitOptions,
kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<BlockwiseTermFitResult, String> {
fit_gaussian_location_scale_terms(
data,
GaussianLocationScaleTermSpec {
y: spec.y.clone(),
weights: spec.weights.clone(),
meanspec: spec.meanspec.clone(),
log_sigmaspec: spec.log_sigmaspec.clone(),
mean_offset: spec.mean_offset.clone(),
log_sigma_offset: spec.log_sigma_offset.clone(),
},
options,
kappa_options,
)
}
fn refit_with_selected_wiggle(
data: ArrayView2<'_, f64>,
spec: Self::Spec,
pilot: &BlockwiseTermFitResult,
wiggle_cfg: &LinkWiggleConfig,
options: &BlockwiseFitOptions,
kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<BlockwiseTermWiggleFitResult, String> {
let selected_wiggle_basis = select_gaussian_location_scale_link_wiggle_basis_from_pilot(
pilot,
&WiggleBlockConfig {
degree: wiggle_cfg.degree,
num_internal_knots: wiggle_cfg.num_internal_knots,
penalty_order: 2,
double_penalty: wiggle_cfg.double_penalty,
},
&wiggle_cfg.penalty_orders,
)?;
fit_gaussian_location_scale_terms_with_selected_wiggle(
data,
spec,
selected_wiggle_basis,
options,
kappa_options,
)
}
fn fit_plain(
data: ArrayView2<'_, f64>,
spec: Self::Spec,
options: &BlockwiseFitOptions,
kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<BlockwiseTermFitResult, String> {
fit_gaussian_location_scale_terms(data, spec, options, kappa_options)
}
fn assemble_plain(fit: BlockwiseTermFitResult) -> Self::Result {
GaussianLocationScaleFitResult {
fit,
wiggle_knots: None,
wiggle_degree: None,
beta_link_wiggle: None,
response_scale: 1.0,
}
}
fn assemble_with_wiggle(
fit: BlockwiseTermFitResult,
wiggle_knots: Array1<f64>,
wiggle_degree: usize,
beta_link_wiggle: Option<Vec<f64>>,
) -> Self::Result {
GaussianLocationScaleFitResult {
fit,
wiggle_knots: Some(wiggle_knots),
wiggle_degree: Some(wiggle_degree),
beta_link_wiggle,
response_scale: 1.0,
}
}
}
struct BinomialLocationScaleWorkflow;
impl LocationScaleWorkflowAdapter for BinomialLocationScaleWorkflow {
type Spec = BinomialLocationScaleTermSpec;
type Request<'a> = BinomialLocationScaleFitRequest<'a>;
type Result = BinomialLocationScaleFitResult;
fn into_parts<'a>(request: Self::Request<'a>) -> LocationScaleWorkflowParts<'a, Self::Spec> {
LocationScaleWorkflowParts {
data: request.data,
spec: request.spec,
wiggle: request.wiggle,
options: request.options,
kappa_options: request.kappa_options,
}
}
fn fit_pilot(
data: ArrayView2<'_, f64>,
spec: &Self::Spec,
options: &BlockwiseFitOptions,
kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<BlockwiseTermFitResult, String> {
require_inverse_link_supports_joint_wiggle(
&spec.link_kind,
"binomial location-scale link wiggle",
)?;
fit_binomial_location_scale_terms(
data,
BinomialLocationScaleTermSpec {
y: spec.y.clone(),
weights: spec.weights.clone(),
link_kind: spec.link_kind.clone(),
thresholdspec: spec.thresholdspec.clone(),
log_sigmaspec: spec.log_sigmaspec.clone(),
threshold_offset: spec.threshold_offset.clone(),
log_sigma_offset: spec.log_sigma_offset.clone(),
},
options,
kappa_options,
)
}
fn refit_with_selected_wiggle(
data: ArrayView2<'_, f64>,
spec: Self::Spec,
pilot: &BlockwiseTermFitResult,
wiggle_cfg: &LinkWiggleConfig,
options: &BlockwiseFitOptions,
kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<BlockwiseTermWiggleFitResult, String> {
let selected_wiggle_basis = select_binomial_location_scale_link_wiggle_basis_from_pilot(
pilot,
&WiggleBlockConfig {
degree: wiggle_cfg.degree,
num_internal_knots: wiggle_cfg.num_internal_knots,
penalty_order: 2,
double_penalty: wiggle_cfg.double_penalty,
},
&wiggle_cfg.penalty_orders,
)?;
fit_binomial_location_scale_terms_with_selected_wiggle(
data,
spec,
selected_wiggle_basis,
options,
kappa_options,
)
}
fn fit_plain(
data: ArrayView2<'_, f64>,
spec: Self::Spec,
options: &BlockwiseFitOptions,
kappa_options: &SpatialLengthScaleOptimizationOptions,
) -> Result<BlockwiseTermFitResult, String> {
fit_binomial_location_scale_terms(data, spec, options, kappa_options)
}
fn assemble_plain(fit: BlockwiseTermFitResult) -> Self::Result {
BinomialLocationScaleFitResult {
fit,
wiggle_knots: None,
wiggle_degree: None,
beta_link_wiggle: None,
}
}
fn assemble_with_wiggle(
fit: BlockwiseTermFitResult,
wiggle_knots: Array1<f64>,
wiggle_degree: usize,
beta_link_wiggle: Option<Vec<f64>>,
) -> Self::Result {
BinomialLocationScaleFitResult {
fit,
wiggle_knots: Some(wiggle_knots),
wiggle_degree: Some(wiggle_degree),
beta_link_wiggle,
}
}
}
/// Population standard deviation (divisor `n`, not `n - 1`) of `v`.
///
/// Returns `0.0` for an empty slice. If the variance is NaN (e.g. from
/// non-finite inputs), `f64::max` returns the non-NaN operand, so the result
/// is also `0.0`.
fn gaussian_response_sample_std(v: ArrayView1<'_, f64>) -> f64 {
    let count = v.len();
    if count == 0 {
        return 0.0;
    }
    let n = count as f64;
    let mean = v.iter().sum::<f64>() / n;
    let sum_sq_dev = v
        .iter()
        .map(|&x| (x - mean) * (x - mean))
        .sum::<f64>();
    (sum_sq_dev / n).max(0.0).sqrt()
}
/// Map a Gaussian location-scale fit computed on a standardized response back
/// to the raw response scale.
///
/// `fit_gaussian_location_scale_model` divides `y` and `mean_offset` by
/// `response_scale` before fitting; this undoes that in place:
/// - Mean/Location/LinkWiggle coefficients (per-block, joint `beta` and block
///   states) are multiplied by `s`.
/// - `ln(s)` is added to the Scale block's intercept columns
///   (`noise_design.intercept_range`).
/// - Wiggle knots and `beta_link_wiggle` are multiplied by `s`.
/// - The covariances are rescaled as `D C D` with `D = diag(row_factors)`.
/// - `standard_deviation` and `max_abs_eta` are multiplied by `s`.
/// - The likelihood-type summaries are shifted by `n·ln(s)` terms.
///
/// Finally `response_scale` is recorded on the result.
fn rescale_gaussian_location_scale_to_raw(
    result: &mut GaussianLocationScaleFitResult,
    response_scale: f64,
) {
    use crate::estimate::BlockRole;
    let s = response_scale;
    let ln_s = s.ln();
    let scale_intercept_range = result.fit.noise_design.intercept_range.clone();
    // Running start of the current block inside the joint coefficient vector.
    let mut joint_offset = 0usize;
    for (block_idx, block) in result.fit.fit.blocks.iter_mut().enumerate() {
        let block_len = block.beta.len();
        match block.role {
            BlockRole::Mean | BlockRole::Location | BlockRole::LinkWiggle => {
                block.beta.mapv_inplace(|v| v * s);
                // Guarded: the joint beta may be shorter than the block layout.
                if result.fit.fit.beta.len() >= joint_offset + block_len {
                    for i in 0..block_len {
                        result.fit.fit.beta[joint_offset + i] *= s;
                    }
                }
                // NOTE(review): only `state.beta` is rescaled here; `state.eta`
                // keeps its standardized values — confirm downstream readers
                // of `block_states[..].eta` expect that.
                if let Some(state) = result.fit.fit.block_states.get_mut(block_idx) {
                    state.beta.mapv_inplace(|v| v * s);
                }
            }
            BlockRole::Scale => {
                // Shift only the intercept column(s). NOTE(review): if the noise
                // design has no intercept, no ln(s) shift is applied at all.
                for col in scale_intercept_range.clone() {
                    if col < block.beta.len() {
                        block.beta[col] += ln_s;
                    }
                    let joint_col = joint_offset + col;
                    if joint_col < result.fit.fit.beta.len() {
                        result.fit.fit.beta[joint_col] += ln_s;
                    }
                    if let Some(state) = result.fit.fit.block_states.get_mut(block_idx)
                        && col < state.beta.len()
                    {
                        state.beta[col] += ln_s;
                    }
                }
            }
            BlockRole::Time | BlockRole::Threshold => {
                // Left untouched by the response rescaling.
            }
        }
        joint_offset += block_len;
    }
    if let Some(knots) = result.wiggle_knots.as_mut() {
        knots.mapv_inplace(|v| v * s);
    }
    if let Some(beta_w) = result.beta_link_wiggle.as_mut() {
        for coef in beta_w.iter_mut() {
            *coef *= s;
        }
    }
    // Per-coefficient multiplicative factor, in joint order: `s` for the
    // rescaled roles, 1.0 otherwise (the additive ln(s) shift leaves the
    // covariance unchanged).
    let mut row_factors: Vec<f64> = Vec::new();
    for block in &result.fit.fit.blocks {
        let f = match block.role {
            BlockRole::Mean | BlockRole::Location | BlockRole::LinkWiggle => s,
            BlockRole::Scale | BlockRole::Time | BlockRole::Threshold => 1.0,
        };
        row_factors.extend(std::iter::repeat_n(f, block.beta.len()));
    }
    // C[i, j] *= f_i * f_j over the overlapping square region.
    let rescale_cov = |cov: &mut Array2<f64>| {
        let m = cov.nrows().min(cov.ncols()).min(row_factors.len());
        for i in 0..m {
            for j in 0..m {
                cov[[i, j]] *= row_factors[i] * row_factors[j];
            }
        }
    };
    if let Some(cov) = result.fit.fit.covariance_conditional.as_mut() {
        rescale_cov(cov);
    }
    if let Some(cov) = result.fit.fit.covariance_corrected.as_mut() {
        rescale_cov(cov);
    }
    result.fit.fit.standard_deviation *= s;
    result.fit.fit.max_abs_eta *= s;
    // The observation count is taken from the first block state's eta length;
    // the shifts are skipped when that is unavailable or zero.
    if let Some(n_obs) = result
        .fit
        .fit
        .block_states
        .first()
        .map(|state| state.eta.len() as f64)
        .filter(|&n| n > 0.0)
    {
        // Same value as the outer `ln_s` (redundant shadow).
        let ln_s = s.ln();
        // NOTE(review): uses the unweighted count n; confirm this matches how
        // prior weights enter the log-likelihood.
        result.fit.fit.log_likelihood -= n_obs * ln_s;
        result.fit.fit.deviance += 2.0 * n_obs * ln_s;
        result.fit.fit.reml_score += n_obs * ln_s;
        result.fit.fit.penalized_objective += n_obs * ln_s;
    }
    result.response_scale = s;
}
/// Fit a Gaussian location-scale model on a standardized response.
///
/// `y` and `mean_offset` are divided by the response's population standard
/// deviation (floored at `1e-6`) before fitting. The fit is then mapped back
/// to the raw scale with `rescale_gaussian_location_scale_to_raw`.
fn fit_gaussian_location_scale_model(
    mut request: GaussianLocationScaleFitRequest<'_>,
) -> Result<GaussianLocationScaleFitResult, String> {
    let scale = gaussian_response_sample_std(request.spec.y.view()).max(1e-6);
    let needs_standardizing = scale != 1.0;
    if needs_standardizing {
        let spec = &mut request.spec;
        spec.y.mapv_inplace(|v| v / scale);
        spec.mean_offset.mapv_inplace(|v| v / scale);
    }
    let mut fitted =
        fit_location_scale_with_optional_wiggle::<GaussianLocationScaleWorkflow>(request)?;
    rescale_gaussian_location_scale_to_raw(&mut fitted, scale);
    Ok(fitted)
}
fn fit_dispersion_location_scale_model(
request: DispersionLocationScaleFitRequest<'_>,
) -> Result<DispersionLocationScaleFitResult, String> {
let kind = request.spec.kind;
let fit = fit_dispersion_glm_location_scale_terms(
request.data,
request.spec,
&request.options,
&request.kappa_options,
)?;
Ok(DispersionLocationScaleFitResult { fit, kind })
}
/// Fit a binomial location-scale model, with an optional link wiggle, through
/// the shared location-scale driver.
fn fit_binomial_location_scale_model(
    request: BinomialLocationScaleFitRequest<'_>,
) -> Result<BinomialLocationScaleFitResult, String> {
    type Workflow = BinomialLocationScaleWorkflow;
    fit_location_scale_with_optional_wiggle::<Workflow>(request)
}
/// Half the penalized deviance of a working state, used as the survival REML
/// score.
fn survival_working_reml_score(state: &crate::pirls::WorkingState) -> f64 {
    let penalized_deviance = state.deviance + state.penalty_term;
    penalized_deviance * 0.5
}
/// Recover a Weibull baseline from linear time coefficients.
///
/// The shape is `beta[1]`; the scale is set to `anchor`. Returns `None` when
/// `beta` has fewer than two entries, or when the shape or anchor is not a
/// finite positive number.
fn fitted_weibull_baseline_from_linear_time_beta(
    beta: &Array1<f64>,
    anchor: f64,
) -> Option<crate::families::survival_construction::SurvivalBaselineConfig> {
    let shape = beta.get(1).copied()?;
    let shape_ok = shape.is_finite() && shape > 0.0;
    let anchor_ok = anchor.is_finite() && anchor > 0.0;
    (shape_ok && anchor_ok).then(|| {
        crate::families::survival_construction::SurvivalBaselineConfig {
            target: SurvivalBaselineTarget::Weibull,
            scale: Some(anchor),
            shape: Some(shape),
            rate: None,
            makeham: None,
        }
    })
}
/// Effective degrees of freedom for a survival transformation fit.
///
/// The penalized Hessian `H` from `state` is factorized. The first attempt
/// uses no ridge; later attempts add a ridge of `1e-10 · max|diag H|`, growing
/// ×10 per retry, up to 8 attempts. For each penalty block `k` with
/// `λ_k > 0`, `λ_k · tr(H⁻¹ S_k)` is computed over that block's coefficient
/// range.
///
/// Returns `(p − Σ_k λ_k tr(H⁻¹ S_k), per-block EDFs, dense H)`. Each EDF is
/// clamped to `[0, width]`.
///
/// # Errors
/// Returns `Err` when:
/// - the Hessian cannot be factorized;
/// - a penalty block's range or matrix shape is inconsistent with `H`;
/// - a trace solve fails;
/// - any EDF is non-finite.
fn survival_transformation_edf(
    state: &crate::pirls::WorkingState,
    penalty_blocks: &[PenaltyBlock],
) -> Result<(f64, Vec<f64>, Array2<f64>), String> {
    let h_dense = state.hessian.to_dense();
    let p = h_dense.nrows();
    let h_sym = crate::linalg::matrix::SymmetricMatrix::Dense(h_dense.clone());
    let factor = {
        let scale = h_sym.max_abs_diag();
        // NOTE: if the diagonal is all zeros, `min_step` is 0 and every retry
        // reuses the unridged matrix before giving up.
        let min_step = scale * 1e-10;
        let mut ridge = 0.0_f64;
        let mut attempts = 0_usize;
        loop {
            let candidate = if ridge > 0.0 {
                h_sym.addridge(ridge).unwrap_or_else(|_| h_sym.clone())
            } else {
                h_sym.clone()
            };
            if let Ok(f) = candidate.factorize() {
                break f;
            }
            attempts += 1;
            if attempts >= 8 {
                return Err("survival edf: penalized Hessian could not be factorized".to_string());
            }
            ridge = if ridge <= 0.0 { min_step } else { ridge * 10.0 };
        }
    };
    let mut edf_by_block = vec![0.0_f64; penalty_blocks.len()];
    let mut total_trace = 0.0_f64;
    for (kk, block) in penalty_blocks.iter().enumerate() {
        // An inverted range would underflow the width computation below.
        if block.range.start > block.range.end {
            return Err(format!(
                "survival edf: penalty block {kk} has inverted range {}..{}",
                block.range.start, block.range.end
            ));
        }
        let block_cols = block.range.end - block.range.start;
        if block.lambda <= 0.0 || block_cols == 0 {
            // Unpenalized (or empty) block: every column counts as a full EDF.
            edf_by_block[kk] = block_cols as f64;
            continue;
        }
        // Validate before indexing so schema errors surface as `Err` rather
        // than index panics (mirrors the cause-specific fit's checks).
        if block.range.end > p {
            return Err(format!(
                "survival edf: penalty block {kk} range {}..{} exceeds the {p}-column Hessian",
                block.range.start, block.range.end
            ));
        }
        if block.matrix.nrows() != block_cols || block.matrix.ncols() != block_cols {
            return Err(format!(
                "survival edf: penalty block {kk} has shape {}x{} but range has width {block_cols}",
                block.matrix.nrows(),
                block.matrix.ncols()
            ));
        }
        // RHS = S_k embedded in rows [range] of a p x width matrix, so that
        // H⁻¹·RHS restricted to rows [range] gives H⁻¹[range, range]·S_k.
        let mut rhs = Array2::<f64>::zeros((p, block_cols));
        for c in 0..block_cols {
            for r in 0..block_cols {
                rhs[[block.range.start + r, c]] = block.matrix[[r, c]];
            }
        }
        let sol = factor
            .solvemulti(&rhs)
            .map_err(|e| format!("survival edf trace solve failed: {e}"))?;
        let mut trace = 0.0_f64;
        for j in 0..block_cols {
            trace += sol[[block.range.start + j, j]];
        }
        let lam_trace = block.lambda * trace;
        total_trace += lam_trace;
        edf_by_block[kk] = (block_cols as f64 - lam_trace).clamp(0.0, block_cols as f64);
    }
    let edf_total = (p as f64 - total_trace).clamp(0.0, p as f64);
    // `clamp` propagates NaN, so a non-finite trace is caught here.
    if !edf_total.is_finite() || edf_by_block.iter().any(|v| !v.is_finite()) {
        return Err("survival edf: non-finite effective degrees of freedom".to_string());
    }
    Ok((edf_total, edf_by_block, h_dense))
}
/// Select survival-transformation smoothing parameters by minimizing the
/// unified LAML criterion over the first `num_smoothing` log-lambdas.
///
/// Each evaluation clones `model`, installs the candidate lambdas, runs
/// working-model PIRLS from `beta0` (honouring `structural_lower_bounds`),
/// then evaluates the LAML cost and its rho gradient. The search is bounded
/// to ±12 around the log of the seed lambdas taken from `penalty_blocks`.
/// Penalties at index `>= num_smoothing` keep their seed lambda.
///
/// Returns `Ok(None)` when `num_smoothing == 0`. Otherwise it returns the full
/// lambda vector, one entry per penalty block.
///
/// # Errors
/// Fails when `num_smoothing` exceeds the number of penalty blocks, or when
/// the outer optimizer run fails.
fn optimize_survival_transformation_smoothing(
    model: &crate::families::survival::WorkingModelSurvival,
    penalty_blocks: &[PenaltyBlock],
    num_smoothing: usize,
    beta0: &Array1<f64>,
    structural_lower_bounds: Option<&Array1<f64>>,
) -> Result<Option<Vec<f64>>, String> {
    use crate::solver::outer_strategy::{Derivative, HessianResult, OuterEval, OuterProblem};
    if num_smoothing == 0 {
        return Ok(None);
    }
    // `eval_at` writes `lambdas[k]` for every k < num_smoothing; reject the
    // mismatch here instead of panicking inside the optimizer callback.
    if num_smoothing > penalty_blocks.len() {
        return Err(format!(
            "survival transformation smoothing selection requested {num_smoothing} smoothing parameters but only {} penalty blocks are available",
            penalty_blocks.len()
        ));
    }
    let seed_lambdas: Vec<f64> = penalty_blocks.iter().map(|b| b.lambda).collect();
    let seed_rho = Array1::from_iter(
        seed_lambdas
            .iter()
            .take(num_smoothing)
            .map(|&l| l.max(1e-12).ln()),
    );
    // Inner solve + LAML cost/gradient at a candidate rho for the smoothed blocks.
    let eval_at = |rho_smooth: &Array1<f64>| -> Result<(f64, Array1<f64>), String> {
        let mut candidate = model.clone();
        let mut lambdas = seed_lambdas.clone();
        for k in 0..num_smoothing {
            lambdas[k] = rho_smooth[k].exp();
        }
        candidate
            .set_penalty_lambdas(&lambdas)
            .map_err(|e| e.to_string())?;
        let opts = crate::pirls::WorkingModelPirlsOptions {
            max_iterations: SURVIVAL_TRANSFORMATION_PIRLS_MAX_ITERATIONS,
            convergence_tolerance: SURVIVAL_TRANSFORMATION_PIRLS_CONVERGENCE_TOL,
            adaptive_kkt_tolerance: None,
            max_step_halving: SURVIVAL_TRANSFORMATION_PIRLS_MAX_STEP_HALVING,
            min_step_size: SURVIVAL_TRANSFORMATION_PIRLS_MIN_STEP_SIZE,
            firth_bias_reduction: false,
            coefficient_lower_bounds: structural_lower_bounds.cloned(),
            linear_constraints: None,
            initial_lm_lambda: None,
            geodesic_acceleration: false,
            arrow_schur: None,
        };
        let summary = crate::pirls::runworking_model_pirls(
            &mut candidate,
            crate::types::Coefficients::new(beta0.clone()),
            &opts,
            |_| {},
        )
        .map_err(|err| format!("survival smoothing PIRLS failed: {err}"))?;
        let beta = summary.beta.as_ref().to_owned();
        let state = candidate
            .update_state(&beta)
            .map_err(|err| format!("survival smoothing state eval failed: {err}"))?;
        // LAML is evaluated over all strictly positive lambdas; the first
        // `num_smoothing` entries of its gradient belong to the smoothed blocks.
        let full_rho = Array1::from_iter(lambdas.iter().filter(|&&l| l > 0.0).map(|&l| l.ln()));
        let (cost, grad_full) = candidate
            .unified_lamlobjective_and_rhogradient(&beta, &state, &full_rho)
            .map_err(|err| format!("survival LAML evaluation failed: {err}"))?;
        if grad_full.len() < num_smoothing || !cost.is_finite() {
            return Err("survival LAML returned an inconsistent gradient/cost".to_string());
        }
        let grad = grad_full.slice(s![..num_smoothing]).to_owned();
        if grad.iter().any(|g| !g.is_finite()) {
            return Err("survival LAML gradient is non-finite".to_string());
        }
        Ok((cost, grad))
    };
    let lower = seed_rho.mapv(|v| v - 12.0);
    let upper = seed_rho.mapv(|v| v + 12.0);
    let problem = OuterProblem::new(num_smoothing)
        .with_gradient(Derivative::Analytic)
        .with_hessian(crate::solver::outer_strategy::DeclaredHessianForm::Unavailable)
        .with_tolerance(1e-4)
        .with_max_iter(120)
        .with_bounds(lower, upper)
        .with_initial_rho(seed_rho)
        .with_seed_config(crate::seeding::SeedConfig {
            max_seeds: 1,
            seed_budget: 1,
            ..Default::default()
        });
    let context =
        format!("survival transformation smoothing-parameter selection (dim={num_smoothing})");
    let mut obj = problem.build_objective(
        (),
        |_: &mut (), rho: &Array1<f64>| {
            eval_at(rho)
                .map(|(c, _)| c)
                .map_err(crate::estimate::EstimationError::InvalidInput)
        },
        |_: &mut (), rho: &Array1<f64>| {
            let (cost, gradient) =
                eval_at(rho).map_err(crate::estimate::EstimationError::InvalidInput)?;
            Ok(OuterEval {
                cost,
                gradient,
                hessian: HessianResult::Unavailable,
                inner_beta_hint: None,
            })
        },
        None::<fn(&mut ())>,
        None::<
            fn(
                &mut (),
                &Array1<f64>,
            )
                -> Result<crate::solver::outer_strategy::EfsEval, crate::estimate::EstimationError>,
        >,
    );
    let result = problem
        .run(&mut obj, &context)
        .map_err(|err| format!("{context} failed: {err}"))?;
    let selected_rho = result.rho;
    // Only accept finite, positive lambdas; otherwise keep the seed value.
    let mut lambdas = seed_lambdas;
    for k in 0..num_smoothing.min(selected_rho.len()) {
        let lam = selected_rho[k].exp();
        if lam.is_finite() && lam > 0.0 {
            lambdas[k] = lam;
        }
    }
    Ok(Some(lambdas))
}
/// Package a converged survival working-model solve as a single-block
/// `UnifiedFitResult`.
///
/// All scalar outputs and the coefficient/lambda vectors are validated as
/// finite first. EDFs and the penalized Hessian come from
/// `survival_transformation_edf`. The REML score is half the penalized
/// deviance and is also used as the penalized objective.
///
/// # Errors
/// Returns `Err` when:
/// - a validated quantity is non-finite;
/// - the EDF computation fails;
/// - the number of penalty blocks differs from the number of supplied lambdas;
/// - `UnifiedFitResult::try_from_parts` rejects the assembled parts.
fn survival_unified_fit_result(
    beta: Array1<f64>,
    lambdas: Array1<f64>,
    summary: &crate::pirls::WorkingModelPirlsResult,
    state: &crate::pirls::WorkingState,
    penalty_blocks: &[PenaltyBlock],
) -> Result<UnifiedFitResult, String> {
    let log_lambdas = lambdas.mapv(|v| v.max(LOG_LAMBDA_UNDERFLOW_FLOOR).ln());
    let reml_score = survival_working_reml_score(state);
    crate::estimate::validate_all_finite("survival fit beta", beta.iter().copied())?;
    crate::estimate::validate_all_finite("survival fit lambdas", lambdas.iter().copied())?;
    crate::estimate::ensure_finite_scalar("survival fit log_likelihood", state.log_likelihood)?;
    crate::estimate::ensure_finite_scalar("survival fit deviance", state.deviance)?;
    crate::estimate::ensure_finite_scalar("survival fit penalty", state.penalty_term)?;
    crate::estimate::ensure_finite_scalar("survival fit reml_score", reml_score)?;
    crate::estimate::ensure_finite_scalar("survival fit gradient_norm", summary.lastgradient_norm)?;
    crate::estimate::ensure_finite_scalar("survival fit max_abs_eta", summary.max_abs_eta)?;
    let (edf_total, edf_by_block, penalized_hessian) =
        survival_transformation_edf(state, penalty_blocks)?;
    // Report a caller-side mismatch as an error instead of panicking.
    if edf_by_block.len() != lambdas.len() {
        return Err(format!(
            "survival fit: {} penalty blocks produced EDF entries but {} smoothing parameters were supplied",
            edf_by_block.len(),
            lambdas.len()
        ));
    }
    let inference = crate::estimate::FitInference {
        edf_by_block,
        edf_total,
        smoothing_correction: None,
        penalized_hessian: penalized_hessian.into(),
        working_weights: Array1::zeros(0),
        working_response: Array1::zeros(0),
        reparam_qs: None,
        dispersion: crate::estimate::Dispersion::Known(1.0),
        beta_covariance: None,
        beta_standard_errors: None,
        beta_covariance_corrected: None,
        beta_standard_errors_corrected: None,
        beta_covariance_frequentist: None,
        coefficient_influence: None,
        weighted_gram: None,
        bias_correction_beta: None,
    };
    UnifiedFitResult::try_from_parts(crate::estimate::UnifiedFitResultParts {
        blocks: vec![crate::estimate::FittedBlock {
            beta,
            role: crate::estimate::BlockRole::Mean,
            edf: edf_total,
            lambdas: lambdas.clone(),
        }],
        log_lambdas,
        lambdas,
        likelihood_family: Some(LikelihoodSpec::royston_parmar()),
        likelihood_scale: crate::types::LikelihoodScaleMetadata::Unspecified,
        log_likelihood_normalization: crate::types::LogLikelihoodNormalization::UserProvided,
        log_likelihood: state.log_likelihood,
        deviance: state.deviance,
        reml_score,
        stable_penalty_term: state.penalty_term,
        penalized_objective: reml_score,
        outer_iterations: summary.iterations,
        // NOTE(review): reported as converged regardless of `summary.status`.
        outer_converged: true,
        outer_gradient_norm: Some(summary.lastgradient_norm),
        standard_deviation: 1.0,
        covariance_conditional: None,
        covariance_corrected: None,
        inference: Some(inference),
        fitted_link: FittedLinkState::Standard(None),
        geometry: None,
        block_states: Vec::new(),
        pirls_status: summary.status,
        max_abs_eta: summary.max_abs_eta,
        constraint_kkt: None,
        artifacts: crate::estimate::FitArtifacts {
            pirls: None,
            ..Default::default()
        },
        inner_cycles: 0,
    })
    .map_err(|err| err.to_string())
}
/// Tile a pooled coefficient seed once per cause.
///
/// The result has length `pooled_seed.len() * cause_count`, and chunk `c`
/// (0-based) is a copy of `pooled_seed`.
fn replicate_pooled_baseline_seed_per_cause(
    pooled_seed: ArrayView1<'_, f64>,
    cause_count: usize,
) -> Array1<f64> {
    Array1::from_iter((0..cause_count).flat_map(|_| pooled_seed.iter().copied()))
}
/// Fit a cause-specific survival transformation model through the
/// custom-family blockwise fitter.
///
/// One parameter block is built per cause code `1..=cause_count`. Every block
/// shares the same `p`-column design: the time-basis columns followed by the
/// covariate columns. Each block also receives the same penalties and its own
/// `p`-wide slice of `beta0_flat` as its starting point.
///
/// # Errors
/// Returns `Err` when:
/// - there is no cause;
/// - `beta0_flat` is not `p * cause_count` long;
/// - a penalty block's range or shape is inconsistent;
/// - the family or rho-prior construction fails;
/// - the fit fails;
/// - a Weibull baseline cannot be recovered.
fn fit_cause_specific_survival_transformation_custom(
    spec: &SurvivalTransformationTermSpec,
    resolvedspec: TermCollectionSpec,
    baseline_cfg: crate::families::survival_construction::SurvivalBaselineConfig,
    prepared: PreparedSurvivalTimeStack,
    dense_cov_design: &Array2<f64>,
    penalty_blocks: Vec<PenaltyBlock>,
    beta0_flat: Array1<f64>,
    derivative_floor: f64,
    penalty_block_gamma_priors: &[(String, f64, f64)],
) -> Result<SurvivalTransformationFitResult, String> {
    let cause_count = crate::survival::cause_count_from_event_codes(spec.event_target.view())
        .into_workflow_result()?;
    if cause_count == 0 {
        return Err(WorkflowError::MissingDependency {
            reason: "cause-specific custom survival fit requires at least one cause".to_string(),
        }
        .into());
    }
    let n = spec.event_target.len();
    let p_time_total = prepared.time_design_exit.ncols();
    let p_cov = dense_cov_design.ncols();
    let p = p_time_total + p_cov;
    if beta0_flat.len() != p * cause_count {
        return Err(WorkflowError::SchemaMismatch {
            reason: format!(
                "cause-specific survival initial beta length mismatch: got {}, expected {}",
                beta0_flat.len(),
                p * cause_count
            ),
        }
        .into());
    }
    // Assemble [time | covariate] designs. Only the time columns of the
    // derivative design are filled; its covariate columns stay zero.
    let dense_time_entry = prepared.time_design_entry.to_dense();
    let dense_time_exit = prepared.time_design_exit.to_dense();
    let dense_time_derivative = prepared.time_design_derivative_exit.to_dense();
    let mut x_entry = Array2::<f64>::zeros((n, p));
    let mut x_exit = Array2::<f64>::zeros((n, p));
    let mut x_derivative = Array2::<f64>::zeros((n, p));
    if p_time_total > 0 {
        x_entry
            .slice_mut(s![.., ..p_time_total])
            .assign(&dense_time_entry);
        x_exit
            .slice_mut(s![.., ..p_time_total])
            .assign(&dense_time_exit);
        x_derivative
            .slice_mut(s![.., ..p_time_total])
            .assign(&dense_time_derivative);
    }
    if p_cov > 0 {
        x_entry
            .slice_mut(s![.., p_time_total..])
            .assign(dense_cov_design);
        x_exit
            .slice_mut(s![.., p_time_total..])
            .assign(dense_cov_design);
    }
    let mut family_blocks = Vec::with_capacity(cause_count);
    let mut block_specs = Vec::with_capacity(cause_count);
    for cause in 0..cause_count {
        // Cause codes are 1-based; this cause's event indicator is 1 exactly
        // where the observed code matches.
        let cause_code = (cause + 1) as u8;
        let event_target = spec
            .event_target
            .mapv(|observed| u8::from(observed == cause_code));
        family_blocks.push(crate::survival::CauseSpecificRoystonParmarBlock {
            age_entry: spec.age_entry.clone(),
            age_exit: spec.age_exit.clone(),
            event_target,
            sampleweight: spec.weights.clone(),
            x_entry: x_entry.clone(),
            x_exit: x_exit.clone(),
            x_derivative: x_derivative.clone(),
            offset_eta_entry: prepared.eta_offset_entry.clone() + &spec.covariate_offset,
            offset_eta_exit: prepared.eta_offset_exit.clone() + &spec.covariate_offset,
            offset_derivative_exit: prepared.derivative_offset_exit.clone(),
            derivative_floor,
        });
        // NOTE(review): the penalty validation and construction below do not
        // depend on `cause`, so they repeat identically for every cause.
        let mut penalties = Vec::with_capacity(penalty_blocks.len());
        let mut nullspace_dims = Vec::with_capacity(penalty_blocks.len());
        let mut initial_log_lambdas = Array1::<f64>::zeros(penalty_blocks.len());
        for (penalty_idx, block) in penalty_blocks.iter().enumerate() {
            if block.range.end > p || block.range.start > block.range.end {
                return Err(WorkflowError::SchemaMismatch {
                    reason: "cause-specific survival penalty range is out of bounds".to_string(),
                }
                .into());
            }
            let block_dim = block.range.end - block.range.start;
            if block.matrix.nrows() != block_dim || block.matrix.ncols() != block_dim {
                return Err(WorkflowError::SchemaMismatch {
                    reason: format!(
                        "cause-specific survival penalty {penalty_idx} has shape {}x{} but range has width {block_dim}",
                        block.matrix.nrows(),
                        block.matrix.ncols()
                    ),
                }
                .into());
            }
            // The precision label is what `cause_specific_survival_rho_prior`
            // matches user-supplied Gamma hyperpriors against.
            penalties.push(
                PenaltyMatrix::Blockwise {
                    local: block.matrix.clone(),
                    col_range: block.range.clone(),
                    total_dim: p,
                }
                .with_precision_label(format!("cause_specific_survival_penalty_{penalty_idx}")),
            );
            nullspace_dims.push(block.nullspace_dim);
            initial_log_lambdas[penalty_idx] = block.lambda.max(LOG_LAMBDA_UNDERFLOW_FLOOR).ln();
        }
        let beta_start = beta0_flat.slice(s![cause * p..(cause + 1) * p]).to_owned();
        // Earlier causes get a larger gauge priority: 100 + (cause_count - cause),
        // saturating at u8::MAX.
        let cause_priority =
            100u8.saturating_add(u8::try_from(cause_count - cause).unwrap_or(u8::MAX));
        let cause_jacobian = std::sync::Arc::new(AdditiveBlockJacobian {
            design: x_exit.clone(),
            own_output: cause,
            n_family_outputs: cause_count,
        });
        block_specs.push(ParameterBlockSpec {
            name: format!("time_cause_{}", cause + 1),
            design: crate::matrix::DesignMatrix::from(x_exit.clone()),
            offset: prepared.eta_offset_exit.clone() + &spec.covariate_offset,
            penalties,
            nullspace_dims,
            initial_log_lambdas,
            initial_beta: Some(beta_start),
            gauge_priority: cause_priority,
            jacobian_callback: Some(cause_jacobian),
            stacked_design: None,
            stacked_offset: None,
        });
    }
    let family = crate::survival::CauseSpecificRoystonParmarFamily::new(family_blocks)?;
    let fit_options = BlockwiseFitOptions {
        compute_covariance: false,
        ..Default::default()
    };
    let rho_prior =
        cause_specific_survival_rho_prior(penalty_blocks.len(), penalty_block_gamma_priors)?;
    let mut fit = fit_custom_family_with_rho_prior(&family, &block_specs, &fit_options, rho_prior)
        .map_err(|err| format!("cause-specific survival custom-family fit failed: {err}"))?;
    fit.likelihood_family = Some(LikelihoodSpec::royston_parmar());
    let time_basis = crate::families::survival_construction::SavedSurvivalTimeBasis::from_build(
        &spec.time_build,
        spec.time_anchor,
    );
    // For a Weibull likelihood without a time wiggle, the stored baseline is
    // recovered from the leading time coefficients of the FIRST cause's block.
    // NOTE(review): the other causes' time coefficients are not reflected.
    let fitted_baseline_cfg = if spec.likelihood_mode == SurvivalLikelihoodMode::Weibull
        && spec.timewiggle.is_none()
    {
        let first_block = fit.blocks.first().ok_or_else(|| {
            "cause-specific survival fit produced no coefficient blocks".to_string()
        })?;
        let time_beta = first_block
            .beta
            .slice(s![..spec.time_build.x_exit_time.ncols()])
            .to_owned();
        fitted_weibull_baseline_from_linear_time_beta(&time_beta, spec.time_anchor).ok_or_else(|| {
            "failed to recover fitted Weibull scale/shape from the cause-specific linear time coefficients"
                .to_string()
        })?
    } else {
        baseline_cfg
    };
    Ok(SurvivalTransformationFitResult {
        fit,
        resolvedspec,
        baseline_cfg: fitted_baseline_cfg,
        likelihood_mode: spec.likelihood_mode,
        time_basis,
        time_base_ncols: spec.time_build.x_exit_time.ncols(),
        baseline_timewiggle: prepared.timewiggle_block,
    })
}
fn cause_specific_survival_rho_prior(
penalty_count: usize,
penalty_block_gamma_priors: &[(String, f64, f64)],
) -> Result<crate::types::RhoPrior, String> {
if penalty_block_gamma_priors.is_empty() {
return Ok(crate::types::RhoPrior::Flat);
}
let mut keyed = BTreeMap::<String, (f64, f64)>::new();
for (label, shape, rate) in penalty_block_gamma_priors {
if keyed.insert(label.clone(), (*shape, *rate)).is_some() {
return Err(WorkflowError::InvalidConfig {
reason: format!(
"duplicate Gamma precision hyperprior for penalty block label '{label}'"
),
}
.into());
}
if !shape.is_finite() || *shape <= 0.0 {
return Err(WorkflowError::InvalidConfig {
reason: format!(
"Gamma precision hyperprior for penalty block '{label}' requires shape > 0, got {shape}"
),
}
.into());
}
if !rate.is_finite() || *rate < 0.0 {
return Err(WorkflowError::InvalidConfig {
reason: format!(
"Gamma precision hyperprior for penalty block '{label}' requires rate >= 0, got {rate}"
),
}
.into());
}
}
let mut consumed = Vec::<String>::new();
let mut priors = Vec::<crate::types::RhoPrior>::with_capacity(penalty_count);
for penalty_idx in 0..penalty_count {
let label = format!("cause_specific_survival_penalty_{penalty_idx}");
if let Some((shape, rate)) = keyed.get(&label) {
consumed.push(label);
priors.push(crate::types::RhoPrior::GammaPrecision {
shape: *shape,
rate: *rate,
});
} else {
priors.push(crate::types::RhoPrior::Flat);
}
}
let unknown = keyed
.keys()
.filter(|label| !consumed.iter().any(|known| known == *label))
.cloned()
.collect::<Vec<_>>();
if !unknown.is_empty() {
let available = (0..penalty_count)
.map(|idx| format!("cause_specific_survival_penalty_{idx}"))
.collect::<Vec<_>>()
.join(", ");
return Err(WorkflowError::InvalidConfig {
reason: format!(
"unknown Gamma precision hyperprior penalty block label(s): {}; available labels: {available}",
unknown.join(", ")
),
}
.into());
}
Ok(crate::types::RhoPrior::Independent(priors))
}
/// Feed a length-prefixed `f64` vector into `hasher`, element by element.
fn hash_workflow_array_view(hasher: &mut crate::cache::Fingerprinter, array: ArrayView1<'_, f64>) {
    let len = array.len();
    hasher.write_usize(len);
    array.iter().for_each(|&value| hasher.write_f64(value));
}
/// Feed a length-prefixed `u8` vector into `hasher`, widening each code to
/// `usize`.
fn hash_workflow_u8_array(hasher: &mut crate::cache::Fingerprinter, array: ArrayView1<'_, u8>) {
    hasher.write_usize(array.len());
    array
        .iter()
        .copied()
        .map(usize::from)
        .for_each(|code| hasher.write_usize(code));
}
/// Feed a 2-D `f64` array into `hasher`: first its shape, then every element
/// in row-major index order.
fn hash_workflow_array2(hasher: &mut crate::cache::Fingerprinter, array: ArrayView2<'_, f64>) {
    let (rows, cols) = array.dim();
    hasher.write_usize(rows);
    hasher.write_usize(cols);
    // `iter()` walks elements in logical (row-major) index order whatever the
    // memory layout, so the hash does not depend on strides.
    for &value in array.iter() {
        hasher.write_f64(value);
    }
}
/// Hash a design matrix through its dense materialization.
fn hash_workflow_design_matrix(
    hasher: &mut crate::cache::Fingerprinter,
    matrix: &crate::matrix::DesignMatrix,
) {
    hash_workflow_array2(hasher, matrix.to_dense().view());
}
/// Natural log of each penalty block's lambda, floored at
/// `LOG_LAMBDA_UNDERFLOW_FLOOR` before taking the log.
fn survival_transformation_log_lambdas(
    penalty_blocks: &[crate::survival::PenaltyBlock],
) -> Vec<f64> {
    let mut log_lambdas = Vec::with_capacity(penalty_blocks.len());
    for block in penalty_blocks {
        let floored = block.lambda.max(LOG_LAMBDA_UNDERFLOW_FLOOR);
        log_lambdas.push(floored.ln());
    }
    log_lambdas
}
/// Build the persistent warm-start cache key for a survival transformation
/// working-model PIRLS solve.
///
/// The key fingerprints:
/// - a fixed tag and the warm-start cache schema tag;
/// - the likelihood mode, anchor and ridge;
/// - the baseline configuration;
/// - the time-basis build metadata and `n_cols`;
/// - the raw survival data and the covariate design;
/// - the prepared offsets and time designs;
/// - every penalty block;
/// - a subset of the PIRLS options.
///
/// The write order is part of the key, so reordering statements invalidates
/// existing cache entries.
fn persistent_survival_transformation_key(
    spec: &SurvivalTransformationTermSpec,
    baseline_cfg: &crate::families::survival_construction::SurvivalBaselineConfig,
    dense_cov_design: ArrayView2<'_, f64>,
    prepared: &PreparedSurvivalTimeStack,
    penalty_blocks: &[crate::survival::PenaltyBlock],
    opts: &crate::pirls::WorkingModelPirlsOptions,
    n_cols: usize,
) -> String {
    let mut hasher = crate::cache::Fingerprinter::new();
    hasher.write_str("gamfit-persistent-survival-transformation-working-pirls");
    hasher.write_str(&crate::solver::persistent_warm_start::cache_schema_tag());
    hasher.write_str(&format!("{:?}", spec.likelihood_mode));
    hasher.write_f64(spec.time_anchor);
    hasher.write_f64(spec.ridge_lambda);
    hasher.write_str(&format!("{:?}", baseline_cfg.target));
    // Optional values are written as a presence flag followed by the value.
    for value in [
        baseline_cfg.scale,
        baseline_cfg.shape,
        baseline_cfg.rate,
        baseline_cfg.makeham,
    ] {
        hasher.write_bool(value.is_some());
        if let Some(value) = value {
            hasher.write_f64(value);
        }
    }
    // Time-basis build: only the shapes of the raw time designs are hashed
    // here; the prepared time designs are hashed in full further down.
    hasher.write_str(&spec.time_build.basisname);
    hasher.write_usize(spec.time_build.x_entry_time.nrows());
    hasher.write_usize(spec.time_build.x_entry_time.ncols());
    hasher.write_usize(spec.time_build.x_exit_time.nrows());
    hasher.write_usize(spec.time_build.x_exit_time.ncols());
    hasher.write_usize(spec.time_build.x_derivative_time.nrows());
    hasher.write_usize(spec.time_build.x_derivative_time.ncols());
    hasher.write_bool(spec.time_build.degree.is_some());
    if let Some(degree) = spec.time_build.degree {
        hasher.write_usize(degree);
    }
    match spec.time_build.knots.as_ref() {
        Some(knots) => {
            hasher.write_bool(true);
            hasher.write_usize(knots.len());
            for &knot in knots {
                hasher.write_f64(knot);
            }
        }
        None => hasher.write_bool(false),
    }
    match spec.time_build.keep_cols.as_ref() {
        Some(cols) => {
            hasher.write_bool(true);
            hasher.write_usize(cols.len());
            for &col in cols {
                hasher.write_usize(col);
            }
        }
        None => hasher.write_bool(false),
    }
    hasher.write_bool(spec.time_build.smooth_lambda.is_some());
    if let Some(lambda) = spec.time_build.smooth_lambda {
        hasher.write_f64(lambda);
    }
    hasher.write_usize(n_cols);
    // Observation-level data and designs.
    hash_workflow_array_view(&mut hasher, spec.age_entry.view());
    hash_workflow_array_view(&mut hasher, spec.age_exit.view());
    hash_workflow_u8_array(&mut hasher, spec.event_target.view());
    hash_workflow_array_view(&mut hasher, spec.weights.view());
    hash_workflow_array_view(&mut hasher, spec.covariate_offset.view());
    hash_workflow_array2(&mut hasher, dense_cov_design);
    hash_workflow_array_view(&mut hasher, prepared.eta_offset_entry.view());
    hash_workflow_array_view(&mut hasher, prepared.eta_offset_exit.view());
    hash_workflow_array_view(&mut hasher, prepared.derivative_offset_exit.view());
    hash_workflow_design_matrix(&mut hasher, &prepared.time_design_entry);
    hash_workflow_design_matrix(&mut hasher, &prepared.time_design_exit);
    hash_workflow_design_matrix(&mut hasher, &prepared.time_design_derivative_exit);
    hasher.write_usize(penalty_blocks.len());
    for block in penalty_blocks {
        hasher.write_f64(block.lambda);
        hasher.write_usize(block.range.start);
        hasher.write_usize(block.range.end);
        hasher.write_usize(block.nullspace_dim);
        hash_workflow_array2(&mut hasher, block.matrix.view());
    }
    // PIRLS options. NOTE(review): `adaptive_kkt_tolerance`, `initial_lm_lambda`,
    // `geodesic_acceleration` and `arrow_schur` are not hashed, and for
    // `linear_constraints` only presence (not content) is hashed — solves that
    // differ only in those share a key.
    hasher.write_usize(opts.max_iterations);
    hasher.write_f64(opts.convergence_tolerance);
    hasher.write_usize(opts.max_step_halving);
    hasher.write_f64(opts.min_step_size);
    hasher.write_bool(opts.firth_bias_reduction);
    hasher.write_bool(opts.coefficient_lower_bounds.is_some());
    if let Some(bounds) = opts.coefficient_lower_bounds.as_ref() {
        hash_workflow_array_view(&mut hasher, bounds.view());
    }
    hasher.write_bool(opts.linear_constraints.is_some());
    format!("surv-transform-{}", hasher.finish_hex())
}
/// Looks up a persisted PIRLS warm start for a survival transformation fit.
///
/// Returns `Some((beta, lm_lambda))` only when all of the following hold:
/// - a record exists under `key`;
/// - `record.is_compatible` accepts it for this spec's length and `n_cols`;
/// - its stored smoothing parameters match `rho` element-wise within `1e-10`.
///
/// The stored Levenberg–Marquardt lambda is passed through only when it is
/// finite and strictly positive; otherwise it is `None`.
fn load_survival_transformation_persistent_warm_start(
    key: &str,
    spec: &SurvivalTransformationTermSpec,
    n_cols: usize,
    rho: &[f64],
) -> Option<(Array1<f64>, Option<f64>)> {
    let record = crate::solver::persistent_warm_start::load_record(key)?;
    let n_obs = spec.age_entry.len();
    if !record.is_compatible(key, n_obs, n_cols) {
        return None;
    }
    if record.rho.len() != rho.len() {
        return None;
    }
    // A cached iterate is only a valid seed for the exact same smoothing
    // parameters, so require the stored rho to match.
    let rho_unchanged = record
        .rho
        .iter()
        .zip(rho)
        .all(|(&stored, &wanted)| (stored - wanted).abs() <= 1e-10);
    if !rho_unchanged {
        return None;
    }
    log::info!("[warm-start-cache] restored survival transformation warm start key={key}");
    let usable_lm_lambda = record
        .last_pirls_lm_lambda
        .filter(|lambda| lambda.is_finite() && *lambda > 0.0);
    let beta_seed = Array1::from_vec(record.beta);
    Some((beta_seed, usable_lm_lambda))
}
fn store_survival_transformation_persistent_warm_start(
key: &str,
spec: &SurvivalTransformationTermSpec,
n_cols: usize,
rho: Vec<f64>,
beta: &Array1<f64>,
summary: &crate::pirls::WorkingModelPirlsResult,
) {
if beta.len() != n_cols
|| beta.iter().any(|value| !value.is_finite())
|| rho.iter().any(|value| !value.is_finite())
{
return;
}
let mut record = crate::solver::persistent_warm_start::PersistentWarmStartRecord::new(
key.to_string(),
spec.age_entry.len(),
n_cols,
);
record.rho = rho;
record.beta = beta.to_vec();
record.last_inner_iters = summary.iterations;
record.last_inner_converged = matches!(
summary.status,
crate::pirls::PirlsStatus::Converged | crate::pirls::PirlsStatus::StalledAtValidMinimum
);
record.last_pirls_lm_lambda = (summary.final_lm_lambda.is_finite()
&& summary.final_lm_lambda > 0.0)
.then_some(summary.final_lm_lambda);
record.last_pirls_accept_rho = summary
.final_accept_rho
.filter(|value| value.is_finite() && *value >= 0.0);
if let Err(err) = crate::solver::persistent_warm_start::store_record(&record) {
log::warn!(
"[warm-start-cache] failed to persist survival transformation warm start: {err}"
);
}
}
/// Fits a survival transformation model.
///
/// The working model is built with
/// `royston_parmar::working_model_from_time_covariateshared` from a prepared
/// time stack and the covariate design.
///
/// Stages, in order:
/// 1. Build the covariate design from `data` and freeze the term spec against
///    it.
/// 2. When the baseline target is not `Linear`, tune the baseline config with
///    `optimize_survival_baseline_config_with_gradient_only`. Each candidate
///    is scored by `survival_working_reml_score` at its PIRLS optimum, with
///    the gradient from `baseline_chain_rule_gradient`.
/// 3. Rebuild the working model at the final baseline config. If the event
///    codes carry more than one cause, or penalty-block gamma priors are
///    given, delegate to `fit_cause_specific_survival_transformation_custom`.
/// 4. Otherwise, optionally select smoothing lambdas, then run PIRLS. PIRLS
///    is seeded from the persistent warm-start cache when a compatible record
///    exists. Any status other than `Converged`/`StalledAtValidMinimum` is an
///    error. On success the result is written back to the cache.
///
/// # Errors
/// Returns a `String` naming the failing stage: design/model construction,
/// PIRLS failure or non-convergence, or Weibull scale/shape recovery.
fn fit_survival_transformation_model(
    request: SurvivalTransformationFitRequest<'_>,
) -> Result<SurvivalTransformationFitResult, String> {
    use crate::survival::{MonotonicityPenalty, PenaltyBlock, PenaltyBlocks, SurvivalSpec};
    // The cache session is accepted but not used on this path.
    let SurvivalTransformationFitRequest {
        data,
        spec,
        cache_session: _cache_session,
    } = request;
    let mut baseline_cfg = spec.baseline_cfg.clone();
    let covariate_design =
        build_term_collection_design(data, &spec.covariate_spec).map_err(|err| err.to_string())?;
    let resolvedspec =
        crate::smooth::freeze_term_collection_from_design(&spec.covariate_spec, &covariate_design)
            .map_err(|err| err.to_string())?;
    let dense_cov_design = covariate_design.design.to_dense();
    let p_cov = dense_cov_design.ncols();
    // More than one cause routes to the cause-specific fitter below.
    let cause_count = crate::survival::cause_count_from_event_codes(spec.event_target.view())
        .into_workflow_result()?;
    let exact_derivative_guard = survival_derivative_guard_for_likelihood(spec.likelihood_mode);
    // Builds everything that depends on the baseline config. Returns
    // (prepared time stack, penalty blocks, beta0, structural lower bounds,
    // working model, number of penalty blocks present before the ridge).
    let build_working_model =
        |candidate: &crate::families::survival_construction::SurvivalBaselineConfig| {
            let prepared = prepare_survival_time_stack(
                &spec.age_entry,
                &spec.age_exit,
                candidate,
                spec.likelihood_mode,
                None,
                spec.time_anchor,
                exact_derivative_guard,
                &spec.time_build,
                spec.timewiggle.as_ref(),
                None,
            )?;
            // Fold the fixed covariate offset into both linear-predictor offsets.
            let mut eta_offset_entry = prepared.eta_offset_entry.clone();
            let mut eta_offset_exit = prepared.eta_offset_exit.clone();
            eta_offset_entry += &spec.covariate_offset;
            eta_offset_exit += &spec.covariate_offset;
            // Coefficient layout: [time block (p_time_total) | covariate block (p_cov)].
            let p_time_total = prepared.time_design_exit.ncols();
            let p = p_time_total + p_cov;
            let mut penalty_blocks = Vec::<PenaltyBlock>::new();
            // Keep only time penalties sized to the whole time block. Lambda
            // falls back to 1e-2 when no smooth_lambda is configured.
            for (idx, penalty) in prepared.time_penalties.iter().enumerate() {
                if penalty.nrows() == p_time_total && penalty.ncols() == p_time_total {
                    penalty_blocks.push(PenaltyBlock {
                        matrix: penalty.clone(),
                        lambda: spec.time_build.smooth_lambda.unwrap_or(1e-2),
                        range: 0..p_time_total,
                        nullspace_dim: prepared.time_nullspace_dims.get(idx).copied().unwrap_or(0),
                    });
                }
            }
            // Covariate penalties are kept only when all of these hold:
            // - the block is non-empty and the matrix is correctly sized;
            // - the prior mean is zero;
            // - the block lies inside the covariate columns.
            // Kept blocks are shifted past the time columns.
            for cov_penalty in &covariate_design.penalties {
                let cr = &cov_penalty.col_range;
                let block_dim = cr.end - cr.start;
                let matches_dims = cov_penalty.local.nrows() == block_dim
                    && cov_penalty.local.ncols() == block_dim;
                let zero_prior = matches!(
                    cov_penalty.prior_mean,
                    crate::estimate::CoefficientPriorMean::Zero
                );
                if block_dim > 0 && matches_dims && zero_prior && cr.end <= p_cov {
                    penalty_blocks.push(PenaltyBlock {
                        matrix: cov_penalty.local.clone(),
                        lambda: 1e-2,
                        range: (p_time_total + cr.start)..(p_time_total + cr.end),
                        nullspace_dim: 0,
                    });
                }
            }
            // Count taken before the ridge is appended. It is passed to the
            // smoothing optimizer as `num_smoothing_blocks`.
            let num_smoothing_blocks = penalty_blocks.len();
            // For the built-in linear Weibull basis (no timewiggle), column 0
            // is excluded from the ridge. It is the column seeded with
            // -shape*ln(scale) below.
            let ridge_range_start = if spec.likelihood_mode == SurvivalLikelihoodMode::Weibull
                && spec.time_build.basisname == "linear"
                && spec.timewiggle.is_none()
            {
                1
            } else {
                0
            };
            if spec.ridge_lambda > 0.0 && p > ridge_range_start {
                let dim = p - ridge_range_start;
                let mut ridge = Array2::<f64>::zeros((dim, dim));
                for d in 0..dim {
                    ridge[[d, d]] = 1.0;
                }
                penalty_blocks.push(PenaltyBlock {
                    matrix: ridge,
                    lambda: spec.ridge_lambda,
                    range: ridge_range_start..p,
                    nullspace_dim: 0,
                });
            }
            let dense_time_entry = prepared.time_design_entry.to_dense();
            let dense_time_exit = prepared.time_design_exit.to_dense();
            let dense_time_derivative = prepared.time_design_derivative_exit.to_dense();
            // The working model sees a single binary event indicator
            // (any label > 0) and no competing events.
            let event_competing = Array1::<u8>::zeros(spec.event_target.len());
            let baseline_event_indicator = spec.event_target.mapv(|label| u8::from(label > 0));
            let mut model =
                crate::families::royston_parmar::working_model_from_time_covariateshared(
                    PenaltyBlocks::new(penalty_blocks.clone()),
                    MonotonicityPenalty { tolerance: 0.0 },
                    SurvivalSpec::Net,
                    crate::families::royston_parmar::RoystonParmarSharedTimeCovariateInputs {
                        age_entry: spec.age_entry.view(),
                        age_exit: spec.age_exit.view(),
                        event_target: baseline_event_indicator.view(),
                        event_competing: event_competing.view(),
                        weights: spec.weights.view(),
                        time_entry: dense_time_entry.view(),
                        time_exit: dense_time_exit.view(),
                        time_derivative: dense_time_derivative.view(),
                        covariates: dense_cov_design.view(),
                        monotonicity_constraint_rows: None,
                        monotonicity_constraint_offsets: None,
                        eta_offset_entry: Some(eta_offset_entry.view()),
                        eta_offset_exit: Some(eta_offset_exit.view()),
                        derivative_offset_exit: Some(prepared.derivative_offset_exit.view()),
                    },
                )
                .map_err(|err| format!("failed to construct survival model: {err}"))?;
            // Non-Weibull modes enable structural monotonicity over the time block.
            if spec.likelihood_mode != SurvivalLikelihoodMode::Weibull {
                model
                    .set_structural_monotonicity(true, p_time_total)
                    .map_err(|err| format!("failed to enable structural monotonicity: {err}"))?;
            }
            let mut beta0 = Array1::<f64>::zeros(p);
            // Weibull seed on the built-in basis:
            // beta0[0] = -shape * ln(scale), beta0[1] = shape.
            if spec.likelihood_mode == SurvivalLikelihoodMode::Weibull && spec.timewiggle.is_none()
            {
                let (scale, shape) = spec
                    .weibull_seed
                    .ok_or_else(|| "weibull survival fit missing scale/shape seed".to_string())?;
                if p_time_total < 2 {
                    return Err(format!(
                        "weibull built-in time basis has {p_time_total} columns but needs 2 to seed scale/shape"
                    ));
                }
                beta0[0] = -shape * scale.ln();
                beta0[1] = shape;
            }
            // Non-Weibull time coefficients are bounded below by 0 and seeded
            // slightly inside the feasible region (1e-4). Covariate columns
            // are unbounded.
            let structural_lower_bounds =
                if spec.likelihood_mode != SurvivalLikelihoodMode::Weibull && p_time_total > 0 {
                    let mut lb = Array1::from_elem(p, f64::NEG_INFINITY);
                    for j in 0..p_time_total {
                        lb[j] = 0.0;
                        beta0[j] = 1e-4;
                    }
                    Some(lb)
                } else {
                    None
                };
            Ok::<_, String>((
                prepared,
                penalty_blocks,
                beta0,
                structural_lower_bounds,
                model,
                num_smoothing_blocks,
            ))
        };
    // Tune a non-linear baseline. Each candidate is scored by the working
    // REML score at its PIRLS optimum, with the chain-rule theta gradient.
    if baseline_cfg.target != SurvivalBaselineTarget::Linear {
        baseline_cfg = optimize_survival_baseline_config_with_gradient_only(
            &baseline_cfg,
            "workflow survival transformation baseline",
            |candidate| {
                let (_, _, beta0, structural_lower_bounds, mut model, _) =
                    build_working_model(candidate)?;
                let opts = crate::pirls::WorkingModelPirlsOptions {
                    max_iterations: SURVIVAL_TRANSFORMATION_PIRLS_MAX_ITERATIONS,
                    convergence_tolerance: SURVIVAL_TRANSFORMATION_PIRLS_CONVERGENCE_TOL,
                    adaptive_kkt_tolerance: None,
                    max_step_halving: SURVIVAL_TRANSFORMATION_PIRLS_MAX_STEP_HALVING,
                    min_step_size: SURVIVAL_TRANSFORMATION_PIRLS_MIN_STEP_SIZE,
                    firth_bias_reduction: false,
                    coefficient_lower_bounds: structural_lower_bounds,
                    linear_constraints: None,
                    initial_lm_lambda: None,
                    geodesic_acceleration: false,
                    arrow_schur: None,
                };
                let summary = crate::pirls::runworking_model_pirls(
                    &mut model,
                    crate::types::Coefficients::new(beta0),
                    &opts,
                    |_| {},
                )
                .map_err(|err| format!("survival PIRLS failed: {err}"))?;
                let beta = summary.beta.as_ref().to_owned();
                let state = model.update_state(&beta).map_err(|err| {
                    format!("failed to evaluate survival baseline candidate: {err}")
                })?;
                let cost = survival_working_reml_score(&state);
                let residuals = model.offset_channel_residuals(&beta).map_err(|err| {
                    format!("failed to form survival baseline offset residuals: {err}")
                })?;
                let gradient = baseline_chain_rule_gradient(
                    spec.age_entry.view(),
                    spec.age_exit.view(),
                    candidate,
                    &residuals,
                )?
                .ok_or_else(|| {
                    "workflow survival transformation baseline unexpectedly has no theta gradient"
                        .to_string()
                })?;
                Ok((cost, gradient))
            },
        )?;
    }
    // Rebuild at the final baseline config.
    let (
        prepared,
        mut penalty_blocks,
        beta0,
        structural_lower_bounds,
        mut model,
        num_smoothing_blocks,
    ) = build_working_model(&baseline_cfg)?;
    // Multi-cause data, or explicit gamma priors on penalty blocks, use the
    // cause-specific fitter. The pooled seed is replicated once per cause.
    if cause_count > 1 || !spec.penalty_block_gamma_priors.is_empty() {
        let beta0_flat = replicate_pooled_baseline_seed_per_cause(beta0.view(), cause_count);
        return fit_cause_specific_survival_transformation_custom(
            &spec,
            resolvedspec,
            baseline_cfg,
            prepared,
            &dense_cov_design,
            penalty_blocks,
            beta0_flat,
            exact_derivative_guard,
            &spec.penalty_block_gamma_priors,
        );
    }
    // When the smoothing optimizer returns lambdas, install them in both the
    // model and the local penalty-block list. The zip covers only as many
    // blocks as lambdas were returned.
    if let Some(selected_lambdas) = optimize_survival_transformation_smoothing(
        &model,
        &penalty_blocks,
        num_smoothing_blocks,
        &beta0,
        structural_lower_bounds.as_ref(),
    )? {
        model
            .set_penalty_lambdas(&selected_lambdas)
            .map_err(|e| e.to_string())?;
        for (block, &lam) in penalty_blocks.iter_mut().zip(selected_lambdas.iter()) {
            block.lambda = lam;
        }
    }
    let opts = crate::pirls::WorkingModelPirlsOptions {
        max_iterations: SURVIVAL_TRANSFORMATION_PIRLS_MAX_ITERATIONS,
        convergence_tolerance: SURVIVAL_TRANSFORMATION_PIRLS_CONVERGENCE_TOL,
        adaptive_kkt_tolerance: None,
        max_step_halving: SURVIVAL_TRANSFORMATION_PIRLS_MAX_STEP_HALVING,
        min_step_size: SURVIVAL_TRANSFORMATION_PIRLS_MIN_STEP_SIZE,
        firth_bias_reduction: false,
        coefficient_lower_bounds: structural_lower_bounds,
        linear_constraints: None,
        initial_lm_lambda: None,
        geodesic_acceleration: false,
        arrow_schur: None,
    };
    // The cache key is computed from the options before any cached LM lambda
    // is injected. The key therefore does not depend on cache contents.
    let rho_for_cache = survival_transformation_log_lambdas(&penalty_blocks);
    let persistent_warm_start_key = persistent_survival_transformation_key(
        &spec,
        &baseline_cfg,
        dense_cov_design.view(),
        &prepared,
        &penalty_blocks,
        &opts,
        beta0.len(),
    );
    let mut opts = opts;
    let beta_start = match load_survival_transformation_persistent_warm_start(
        &persistent_warm_start_key,
        &spec,
        beta0.len(),
        &rho_for_cache,
    ) {
        Some((beta, lm_lambda)) => {
            opts.initial_lm_lambda = lm_lambda;
            beta
        }
        None => beta0,
    };
    let summary = crate::pirls::runworking_model_pirls(
        &mut model,
        crate::types::Coefficients::new(beta_start),
        &opts,
        |_| {},
    )
    .map_err(|err| format!("survival PIRLS failed: {err}"))?;
    // Only a converged or stalled-at-valid-minimum result is accepted.
    match summary.status {
        crate::pirls::PirlsStatus::Converged | crate::pirls::PirlsStatus::StalledAtValidMinimum => {
        }
        ref other => {
            return Err(WorkflowError::IntegrationFailed {
                reason: format!(
                    "survival PIRLS did not converge: status={other:?}, grad_norm={:.3e}, iterations={}, deviance={:.6e}",
                    summary.lastgradient_norm, summary.iterations, summary.state.deviance
                ),
            }
            .into());
        }
    }
    let beta = summary.beta.as_ref().to_owned();
    store_survival_transformation_persistent_warm_start(
        &persistent_warm_start_key,
        &spec,
        beta.len(),
        rho_for_cache,
        &beta,
        &summary,
    );
    let state = model
        .update_state(&beta)
        .map_err(|err| format!("failed to evaluate survival optimum: {err}"))?;
    let lambdas = Array1::from_iter(penalty_blocks.iter().map(|block| block.lambda));
    // For the built-in Weibull basis, report fitted scale/shape recovered
    // from the time coefficients instead of the input baseline config.
    // NOTE(review): this slices by `x_exit_time.ncols()` rather than the
    // `p_time_total` used above. It assumes the two agree when there is no
    // timewiggle — confirm.
    let fitted_baseline_cfg =
        if spec.likelihood_mode == SurvivalLikelihoodMode::Weibull && spec.timewiggle.is_none() {
            let time_beta = beta
                .slice(s![..spec.time_build.x_exit_time.ncols()])
                .to_owned();
            fitted_weibull_baseline_from_linear_time_beta(&time_beta, spec.time_anchor).ok_or_else(
                || {
                    "failed to recover fitted Weibull scale/shape from the linear time coefficients"
                        .to_string()
                },
            )?
        } else {
            baseline_cfg
        };
    let fit = survival_unified_fit_result(beta, lambdas, &summary, &state, &penalty_blocks)?;
    let time_base_ncols = spec.time_build.x_exit_time.ncols();
    let time_basis = crate::families::survival_construction::SavedSurvivalTimeBasis::from_build(
        &spec.time_build,
        spec.time_anchor,
    );
    Ok(SurvivalTransformationFitResult {
        fit,
        resolvedspec,
        baseline_cfg: fitted_baseline_cfg,
        likelihood_mode: spec.likelihood_mode,
        time_basis,
        time_base_ncols,
        baseline_timewiggle: prepared.timewiggle_block,
    })
}
/// Fits a survival location-scale model.
///
/// When `request.optimize_inverse_link` is set, this also optimizes the free
/// parameters of an SAS, beta-logistic or (non-empty-rho) mixture inverse
/// link. Otherwise the model is profiled at the link given in the spec.
fn fit_survival_location_scale_model(
    request: SurvivalLocationScaleFitRequest<'_>,
) -> Result<SurvivalLocationScaleFitResult, String> {
    /// Fits the location-scale terms for the inverse link already stored in `spec`.
    ///
    /// With a link wiggle configured, the fit runs in three steps:
    /// 1. A pilot fit without the link-wiggle block.
    /// 2. Selection of the wiggle basis from that pilot, with penalty order 2.
    /// 3. The final fit with the selected basis; the selected knots and
    ///    degree are recorded in the returned profile.
    fn profile_survival_location_scale(
        data: ArrayView2<'_, f64>,
        spec: SurvivalLocationScaleTermSpec,
        wiggle: Option<LinkWiggleConfig>,
        kappa_options: &SpatialLengthScaleOptimizationOptions,
    ) -> Result<SurvivalLocationScaleProfile, String> {
        let mut wiggle_knots = None;
        let mut wiggle_degree = None;
        let inverse_link = spec.inverse_link.clone();
        let fit = if let Some(wiggle) = wiggle {
            require_inverse_link_supports_joint_wiggle(&inverse_link, "survival link wiggle")?;
            // Pilot fit without the wiggle block; only used to place the wiggle basis.
            let mut pilot_spec = spec.clone();
            pilot_spec.linkwiggle_block = None;
            let pilot = fit_survival_location_scale_terms(data, pilot_spec, kappa_options)?;
            let selected_wiggle_basis = select_survival_link_wiggle_basis_from_pilot(
                &pilot,
                &WiggleBlockConfig {
                    degree: wiggle.degree,
                    num_internal_knots: wiggle.num_internal_knots,
                    penalty_order: 2,
                    double_penalty: wiggle.double_penalty,
                },
                &wiggle.penalty_orders,
            )?;
            wiggle_knots = Some(selected_wiggle_basis.knots.clone());
            wiggle_degree = Some(selected_wiggle_basis.degree);
            fit_survival_location_scale_terms_with_selected_wiggle(
                data,
                spec,
                selected_wiggle_basis,
                kappa_options,
            )?
        } else {
            fit_survival_location_scale_terms(data, spec, kappa_options)?
        };
        Ok(SurvivalLocationScaleProfile {
            fit,
            inverse_link,
            wiggle_knots,
            wiggle_degree,
        })
    }
    /// Profiles a clone of `spec` with `inverse_link` substituted in.
    fn profile_survival_location_scale_with_inverse_link(
        data: ArrayView2<'_, f64>,
        spec: &SurvivalLocationScaleTermSpec,
        inverse_link: InverseLink,
        wiggle: Option<LinkWiggleConfig>,
        kappa_options: &SpatialLengthScaleOptimizationOptions,
    ) -> Result<SurvivalLocationScaleProfile, String> {
        let mut spec_at_link = spec.clone();
        spec_at_link.inverse_link = inverse_link;
        profile_survival_location_scale(data, spec_at_link, wiggle, kappa_options)
    }
    /// Optimizes the spec's inverse-link parameters, when it has any, and
    /// returns the profile at the optimum.
    ///
    /// SAS and beta-logistic links have two parameters (epsilon, log_delta);
    /// a mixture link uses its `rho` vector. Any other link is profiled as-is.
    fn optimize_survival_inverse_link_profile(
        data: ArrayView2<'_, f64>,
        spec: &SurvivalLocationScaleTermSpec,
        wiggle: Option<LinkWiggleConfig>,
        kappa_options: &SpatialLengthScaleOptimizationOptions,
    ) -> Result<SurvivalLocationScaleProfile, String> {
        /// Runs a bounded, gradient-based outer optimization over link
        /// parameters `theta`, then re-profiles at the recovered link.
        ///
        /// - Bounds: each coordinate stays within `init ± 6`.
        /// - Cost: `-log_likelihood + 0.5 * stable_penalty_term` from a full
        ///   location-scale refit at the candidate link.
        /// - Gradient: the fit's `link_param_data_fit_gradient`.
        /// - `make_link` maps an iterate to a link and may reject it.
        /// - `recover` maps the converged iterate back to a link.
        ///
        /// NOTE(review): the gradient is the "data-fit" gradient, while the
        /// cost includes the penalty term. This is consistent only if the
        /// penalty depends on `theta` solely through the inner optimum
        /// (envelope argument) — confirm.
        fn optimize_link_parameters<R>(
            data: ArrayView2<'_, f64>,
            spec: &SurvivalLocationScaleTermSpec,
            kappa_options: &SpatialLengthScaleOptimizationOptions,
            init: Array1<f64>,
            name: &str,
            final_wiggle: Option<LinkWiggleConfig>,
            wiggle_cfg: Option<LinkWiggleConfig>,
            make_link: impl Fn(&Array1<f64>) -> Result<InverseLink, String> + Clone,
            recover: R,
        ) -> Result<SurvivalLocationScaleProfile, String>
        where
            R: Fn(&Array1<f64>) -> Option<InverseLink>,
        {
            use crate::solver::outer_strategy::{
                DeclaredHessianForm, Derivative, HessianResult, OuterEval, OuterProblem,
            };
            let dim = init.len();
            let lower = init.mapv(|v| v - 6.0);
            let upper = init.mapv(|v| v + 6.0);
            // Single-seed run starting at `init`, analytic gradient, no Hessian.
            let problem = OuterProblem::new(dim)
                .with_gradient(Derivative::Analytic)
                .with_hessian(DeclaredHessianForm::Unavailable)
                .with_tolerance(1e-4)
                .with_max_iter(240)
                .with_bounds(lower, upper)
                .with_initial_rho(init.clone())
                .with_seed_config(crate::seeding::SeedConfig {
                    max_seeds: 1,
                    seed_budget: 1,
                    num_auxiliary_trailing: dim,
                    ..Default::default()
                });
            let context = format!("survival inverse-link optimization ({name}, dim={dim})");
            // One evaluation is a full location-scale refit at the candidate
            // link. It returns (cost, gradient) after validating finiteness
            // and dimensions.
            let eval_link = move |theta: &Array1<f64>| -> Result<(f64, Array1<f64>), String> {
                let link = make_link(theta)?;
                let profile = profile_survival_location_scale_with_inverse_link(
                    data,
                    spec,
                    link,
                    wiggle_cfg.clone(),
                    kappa_options,
                )?;
                let cost = -profile.fit.fit.log_likelihood
                    + 0.5 * profile.fit.fit.stable_penalty_term;
                if !cost.is_finite() {
                    return Err(format!(
                        "survival inverse-link ({name}): non-finite profile cost \
                         (log_likelihood={}, stable_penalty_term={})",
                        profile.fit.fit.log_likelihood, profile.fit.fit.stable_penalty_term
                    ));
                }
                let gradient = profile
                    .fit
                    .link_param_data_fit_gradient
                    .clone()
                    .ok_or_else(|| {
                        format!(
                            "survival inverse-link ({name}): fit reported no link-parameter \
                             data-fit gradient"
                        )
                    })?;
                if gradient.len() != theta.len() {
                    return Err(format!(
                        "survival inverse-link ({name}): gradient dim {} != theta dim {}",
                        gradient.len(),
                        theta.len()
                    ));
                }
                Ok((cost, gradient))
            };
            // Adapt `eval_link` to the outer-problem callbacks; string errors
            // become `EstimationError::InvalidInput`.
            let cost_eval = eval_link.clone();
            let cost_fn = move |_: &mut (), theta: &Array1<f64>| {
                cost_eval(theta)
                    .map(|(cost, _)| cost)
                    .map_err(crate::estimate::EstimationError::InvalidInput)
            };
            let eval_fn = move |_: &mut (), theta: &Array1<f64>| {
                let (cost, gradient) =
                    eval_link(theta).map_err(crate::estimate::EstimationError::InvalidInput)?;
                Ok(OuterEval {
                    cost,
                    gradient,
                    hessian: HessianResult::Unavailable,
                    inner_beta_hint: None,
                })
            };
            // No optional callbacks are supplied. The explicit `None::<fn ...>`
            // types only pin the generic parameters.
            let mut obj = problem.build_objective(
                (),
                cost_fn,
                eval_fn,
                None::<fn(&mut ())>,
                None::<
                    fn(
                        &mut (),
                        &Array1<f64>,
                    )
                        -> Result<
                            crate::solver::outer_strategy::EfsEval,
                            crate::estimate::EstimationError,
                        >,
                >,
            );
            let result = problem
                .run(&mut obj, &context)
                .map_err(|err| format!("{context} failed: {err}"))?;
            let link = recover_converged_survival_inverse_link(result, &context, recover)?;
            // Final profile at the recovered link, using `final_wiggle`.
            profile_survival_location_scale_with_inverse_link(
                data,
                spec,
                link,
                final_wiggle,
                kappa_options,
            )
            .map_err(|err| format!("{context} final profiling failed: {err}"))
        }
        // Every optimizable branch passes the same wiggle config for both the
        // search and the final profile.
        match spec.inverse_link.clone() {
            InverseLink::Sas(state0) => optimize_link_parameters(
                data,
                spec,
                kappa_options,
                Array1::from_vec(vec![state0.epsilon, state0.log_delta]),
                "SAS",
                wiggle.clone(),
                wiggle.clone(),
                |theta| {
                    state_from_sasspec(SasLinkSpec {
                        initial_epsilon: theta[0],
                        initial_log_delta: theta[1],
                    })
                    .map(InverseLink::Sas)
                },
                |rho| {
                    state_from_sasspec(SasLinkSpec {
                        initial_epsilon: rho[0],
                        initial_log_delta: rho[1],
                    })
                    .ok()
                    .map(InverseLink::Sas)
                },
            ),
            InverseLink::BetaLogistic(state0) => optimize_link_parameters(
                data,
                spec,
                kappa_options,
                Array1::from_vec(vec![state0.epsilon, state0.log_delta]),
                "BetaLogistic",
                wiggle.clone(),
                wiggle.clone(),
                |theta| {
                    state_from_beta_logisticspec(SasLinkSpec {
                        initial_epsilon: theta[0],
                        initial_log_delta: theta[1],
                    })
                    .map(InverseLink::BetaLogistic)
                },
                |rho| {
                    state_from_beta_logisticspec(SasLinkSpec {
                        initial_epsilon: rho[0],
                        initial_log_delta: rho[1],
                    })
                    .ok()
                    .map(InverseLink::BetaLogistic)
                },
            ),
            // A mixture optimizes its `rho` vector; components stay fixed.
            InverseLink::Mixture(state0) if !state0.rho.is_empty() => {
                let components = state0.components.clone();
                let components_recover = components.clone();
                optimize_link_parameters(
                    data,
                    spec,
                    kappa_options,
                    state0.rho.clone(),
                    "mixture",
                    wiggle.clone(),
                    wiggle.clone(),
                    move |rho| {
                        state_fromspec(&MixtureLinkSpec {
                            components: components.clone(),
                            initial_rho: rho.clone(),
                        })
                        .map(InverseLink::Mixture)
                    },
                    move |rho| {
                        state_fromspec(&MixtureLinkSpec {
                            components: components_recover.clone(),
                            initial_rho: rho.to_owned(),
                        })
                        .ok()
                        .map(InverseLink::Mixture)
                    },
                )
            }
            // Links without free parameters are profiled directly.
            _ => profile_survival_location_scale(data, spec.clone(), wiggle, kappa_options),
        }
    }
    let profile = if request.optimize_inverse_link {
        optimize_survival_inverse_link_profile(
            request.data,
            &request.spec,
            request.wiggle.clone(),
            &request.kappa_options,
        )?
    } else {
        profile_survival_location_scale(
            request.data,
            request.spec.clone(),
            request.wiggle.clone(),
            &request.kappa_options,
        )?
    };
    Ok(profile.into_result())
}
fn fit_bernoulli_marginal_slope_model(
request: BernoulliMarginalSlopeFitRequest<'_>,
) -> Result<BernoulliMarginalSlopeFitResult, String> {
fit_bernoulli_marginal_slope_terms(
request.data,
request.spec,
&request.options,
&request.kappa_options,
&request.policy,
)
}
fn fit_survival_marginal_slope_model(
request: SurvivalMarginalSlopeFitRequest<'_>,
) -> Result<SurvivalMarginalSlopeFitResult, String> {
fit_survival_marginal_slope_terms(
request.data,
request.spec,
&request.options,
&request.kappa_options,
)
}
fn fit_latent_survival_model(
request: LatentSurvivalFitRequest<'_>,
) -> Result<LatentSurvivalTermFitResult, String> {
fit_latent_survival_terms(
request.data,
request.spec,
request.frailty,
&request.options,
)
}
fn fit_latent_binary_model(
request: LatentBinaryFitRequest<'_>,
) -> Result<LatentBinaryTermFitResult, String> {
fit_latent_binary_terms(
request.data,
request.spec,
request.frailty,
&request.options,
)
}
fn fit_transformation_normal_model(
request: TransformationNormalFitRequest<'_>,
) -> Result<TransformationNormalFitResult, String> {
fit_transformation_normal(
&request.response,
&request.weights,
&request.offset,
request.data,
&request.covariate_spec,
&request.config,
&request.options,
&request.kappa_options,
request.warm_start.as_ref(),
)
}
/// Out-of-fold Stage-1 score calibration produced by `crossfit_score_calibration`.
///
/// Row `i` of both fields comes from the fold model that did not see row `i`
/// during training.
pub struct CrossFitScoreCalibration {
    /// Out-of-fold score `z` per row; length equals the dataset row count.
    pub z_oof: Array1<f64>,
    /// Out-of-fold score-influence Jacobian.
    /// Shape is `(n_rows, p1)`, where `p1` is the column count every fold
    /// must agree on.
    pub jac_oof: Array2<f64>,
}
/// Recipe for the Stage-1 conditional transformation-normal fit used by
/// cross-fit score calibration.
///
/// Build it with [`CtnStage1Recipe::new`], which validates the response column
/// and the covariate right-hand side.
#[derive(Clone, Debug)]
pub struct CtnStage1Recipe {
    /// Name of the Stage-1 response column (non-empty).
    pub response_column: String,
    /// Right-hand side of the Stage-1 covariate formula, without `~`.
    /// Example: `s(pc1) + s(pc2)`.
    pub covariate_formula_rhs: String,
    /// Transformation-normal configuration. Cross-fitting uses a copy with
    /// the response knot count pinned.
    pub config: TransformationNormalConfig,
    /// Optional weight column; `None` when absent or blank.
    pub weight_column: Option<String>,
    /// Optional offset column; `None` when absent or blank.
    pub offset_column: Option<String>,
}
impl CtnStage1Recipe {
    /// Builds a validated Stage-1 recipe.
    ///
    /// Surrounding whitespace is trimmed from every column name and from the
    /// covariate RHS. Optional weight/offset columns that are absent or blank
    /// become `None`.
    ///
    /// # Errors
    /// Returns an error when any of the following holds:
    /// - the response column is blank;
    /// - the covariate RHS is blank;
    /// - the covariate RHS contains `~` (only a right-hand side is accepted).
    pub fn new(
        response: &str,
        covariates: &str,
        config: TransformationNormalConfig,
        weight_column: Option<&str>,
        offset_column: Option<&str>,
    ) -> Result<Self, String> {
        let response_column = response.trim().to_string();
        if response_column.is_empty() {
            return Err("CtnStage1Recipe requires a non-empty Stage-1 response column".to_string());
        }
        let covariate_formula_rhs = covariates.trim().to_string();
        if covariate_formula_rhs.is_empty() {
            return Err(
                "CtnStage1Recipe requires a non-empty Stage-1 covariate formula RHS".to_string(),
            );
        }
        if covariate_formula_rhs.contains('~') {
            return Err(
                "CtnStage1Recipe covariates is a right-hand side only; pass 's(pc1) + s(pc2)', \
                 not 'score ~ s(pc1) + s(pc2)'"
                    .to_string(),
            );
        }
        // Trim optional column names the same way as the response column.
        // Otherwise a padded name such as " w " would be stored verbatim and
        // fail the later column lookup.
        let normalize_column = |name: Option<&str>| {
            name.map(str::trim)
                .filter(|trimmed| !trimmed.is_empty())
                .map(str::to_string)
        };
        Ok(Self {
            response_column,
            covariate_formula_rhs,
            config,
            weight_column: normalize_column(weight_column),
            offset_column: normalize_column(offset_column),
        })
    }
}
/// Chooses the number of cross-fit folds for `n` rows.
///
/// | rows                    | folds |
/// |-------------------------|-------|
/// | 0–2                     | 2     |
/// | 3–249                   | 3     |
/// | 250–199,999             | 5     |
/// | 200,000–1,999,999       | 3     |
/// | 2,000,000 and above     | 2     |
///
/// Large datasets use fewer folds to bound the number of full refits.
fn crossfit_fold_count(n: usize) -> usize {
    match n {
        0..=2 => 2,
        3..=249 => 3,
        250..=199_999 => 5,
        200_000..=1_999_999 => 3,
        _ => 2,
    }
}
/// Splits row indices `0..n` into `k` contiguous folds.
///
/// Fold sizes differ by at most one. The first `n % k` folds get the extra
/// row. Folds may be empty when `n < k`.
///
/// # Panics
/// Panics when `k == 0` (division by zero).
fn crossfit_partition(n: usize, k: usize) -> Vec<Vec<usize>> {
    let (base, extra) = (n / k, n % k);
    let mut cursor = 0usize;
    (0..k)
        .map(|fold| {
            let size = if fold < extra { base + 1 } else { base };
            let rows = cursor..cursor + size;
            cursor += size;
            rows.collect::<Vec<usize>>()
        })
        .collect()
}
fn crossfit_select_rows_1d(source: &Array1<f64>, indices: &[usize]) -> Array1<f64> {
Array1::from_iter(indices.iter().map(|&i| source[i]))
}
/// Computes out-of-fold Stage-1 scores and score-influence Jacobians.
///
/// Returns `Ok(None)` when no recipe is given. Otherwise:
/// 1. Resolve the response, weight and offset columns.
/// 2. Parse `"{response} ~ {rhs}"` and build the covariate design once on the
///    full data, then freeze it. Every fold therefore shares one basis.
/// 3. Split the rows into `crossfit_fold_count(n)` contiguous folds. Pin the
///    response knot count, sized for the smallest training complement.
/// 4. For each fold, fit the transformation-normal model on the complement.
///    Evaluate `score_influence_jacobian` on the held-out rows and scatter
///    the results into the full-size outputs.
///
/// NOTE(review): folds are contiguous row blocks with no shuffling. If the
/// input rows are ordered (e.g. by outcome or site), folds may be unbalanced.
///
/// # Errors
/// Fails in any of these cases:
/// - the dataset is empty;
/// - column resolution, formula parsing or design building fails;
/// - a fold has an empty training complement;
/// - a fold fit fails;
/// - the Jacobian or `z` length does not match the held-out fold;
/// - folds disagree on the Jacobian column count.
fn crossfit_score_calibration(
    data: &Dataset,
    col_map: &HashMap<String, usize>,
    recipe: Option<&CtnStage1Recipe>,
    policy: &crate::resource::ResourcePolicy,
) -> Result<Option<CrossFitScoreCalibration>, String> {
    let Some(recipe) = recipe else {
        return Ok(None);
    };
    let n = data.values.nrows();
    if n == 0 {
        return Err("cross-fit score calibration requires a non-empty dataset".to_string());
    }
    let y_col = resolve_role_col(col_map, &recipe.response_column, "response")
        .map_err(|e| e.to_string())?;
    let response_full = data.values.column(y_col).to_owned();
    let weights_full = resolve_weight_column(data, col_map, recipe.weight_column.as_deref())
        .map_err(|e| e.to_string())?;
    let offset_full = resolve_offset_column(data, col_map, recipe.offset_column.as_deref())
        .map_err(|e| e.to_string())?;
    let parsed_cov = parse_formula(&format!(
        "{} ~ {}",
        recipe.response_column, recipe.covariate_formula_rhs
    ))
    .map_err(|e| e.to_string())?;
    let mut frozen_notes = Vec::new();
    let covariate_spec_raw = build_termspec_with_geometry_and_overrides(
        &parsed_cov.terms,
        data,
        col_map,
        &mut frozen_notes,
        false,
        policy,
        None,
    )
    .map_err(|e| e.to_string())?;
    // Freeze the covariate basis on the full data so every fold model uses
    // identical columns.
    let full_cov_design = build_term_collection_design(data.values.view(), &covariate_spec_raw)
        .map_err(|e| e.to_string())?;
    let frozen_cov_spec =
        crate::smooth::freeze_term_collection_from_design(&covariate_spec_raw, &full_cov_design)
            .map_err(|e| e.to_string())?;
    let p_cov = full_cov_design.design.ncols();
    let k = crossfit_fold_count(n);
    let folds = crossfit_partition(n, k);
    // Size the response basis for the smallest training set, then pin it so
    // no fold re-derives a different knot count.
    let min_complement = folds.iter().map(|held| n - held.len()).min().unwrap_or(n);
    let mut fold_config = recipe.config.clone();
    fold_config.response_num_internal_knots =
        crate::families::transformation_normal::effective_response_num_internal_knots(
            &recipe.config,
            min_complement,
            p_cov,
            response_full.view(),
        );
    fold_config.response_num_internal_knots_pinned = true;
    let mut z_oof = Array1::<f64>::zeros(n);
    // Allocated on the first non-empty fold, once the column count p₁ is known.
    let mut jac_oof: Option<Array2<f64>> = None;
    for held in &folds {
        if held.is_empty() {
            continue;
        }
        // Training rows are all rows not in the held-out fold.
        let held_set: std::collections::HashSet<usize> = held.iter().copied().collect();
        let complement: Vec<usize> = (0..n).filter(|i| !held_set.contains(i)).collect();
        if complement.is_empty() {
            return Err(
                "cross-fit fold left an empty training complement; too few rows for K folds"
                    .to_string(),
            );
        }
        let train_cov = data.values.select(Axis(0), &complement);
        let train_resp = crossfit_select_rows_1d(&response_full, &complement);
        let train_weights = crossfit_select_rows_1d(&weights_full, &complement);
        let train_offset = crossfit_select_rows_1d(&offset_full, &complement);
        let fold_fit = fit_transformation_normal(
            &train_resp,
            &train_weights,
            &train_offset,
            train_cov.view(),
            &frozen_cov_spec,
            &fold_config,
            &BlockwiseFitOptions::default(),
            &SpatialLengthScaleOptimizationOptions::default(),
            None,
        )?;
        let held_cov = data.values.select(Axis(0), held);
        let held_resp = crossfit_select_rows_1d(&response_full, held);
        let held_offset = crossfit_select_rows_1d(&offset_full, held);
        let jac = crate::families::marginal_slope_orthogonal::score_influence_jacobian(
            &fold_fit,
            &held_resp,
            held_cov.view(),
            &held_offset,
        )?;
        if jac.columns.nrows() != held.len() {
            return Err(format!(
                "cross-fit fold Jacobian row count {} != held-out fold size {}",
                jac.columns.nrows(),
                held.len()
            ));
        }
        if jac.z.len() != held.len() {
            return Err(format!(
                "cross-fit fold OOF z length {} != held-out fold size {}",
                jac.z.len(),
                held.len()
            ));
        }
        let p1 = jac.columns.ncols();
        let jac_full = jac_oof.get_or_insert_with(|| Array2::<f64>::zeros((n, p1)));
        if jac_full.ncols() != p1 {
            return Err(format!(
                "cross-fit fold p₁ mismatch: this fold has {p1} columns but a prior fold had {}; \
                 the frozen response/covariate basis failed to align across folds",
                jac_full.ncols()
            ));
        }
        // Scatter fold-local rows back to their global row positions.
        for (local, &global) in held.iter().enumerate() {
            z_oof[global] = jac.z[local];
            for c in 0..p1 {
                jac_full[[global, c]] = jac.columns[[local, c]];
            }
        }
    }
    let jac_oof = jac_oof.ok_or_else(|| {
        "cross-fit produced no folds with held-out rows; cannot assemble OOF Jacobian".to_string()
    })?;
    Ok(Some(CrossFitScoreCalibration { z_oof, jac_oof }))
}
/// Fits the model described by `request`, dispatching on its family variant.
///
/// Before dispatching, persistent warm-start sessions may be attached:
/// - an *exact* session opened under `request.cache_key()`. When it has no
///   stored payload, it is preloaded with a seed found under
///   `request.cache_seed_key()`.
/// - a *mirror* session opened under `request.cache_seed_key()`.
///
/// # Errors
/// Every family-specific failure is wrapped as
/// [`WorkflowError::IntegrationFailed`]. A survival location-scale failure
/// whose message reports an empty inner block state is replaced by a longer
/// diagnostic with remediation hints; the original message is appended.
pub fn fit_model(request: FitRequest<'_>) -> Result<FitResult, WorkflowError> {
    let mut request = request;
    let exact_key = request.cache_key();
    let seed_key = request.cache_seed_key();
    if let Some(session) = crate::solver::persistent_warm_start::open_outer_session(&exact_key) {
        let exact_present = session.peek_load().is_some();
        if !exact_present
            && let Some(seed) =
                crate::solver::persistent_warm_start::lookup_outer_iterate_payload(&seed_key)
        {
            let prior_obj = seed.objective.unwrap_or(f64::NAN);
            // Log only the first 8 characters of the key. Cut on a char
            // boundary: byte-slicing `&key[..8]` panics when byte 8 falls
            // inside a multi-byte character.
            let key_prefix: &str = match exact_key.char_indices().nth(8) {
                Some((end, _)) => &exact_key[..end],
                None => &exact_key[..],
            };
            log::info!(
                "[CACHE] seed key={}.. via prefix family={} prior_obj={:.6e}",
                key_prefix,
                request.family_tag(),
                prior_obj,
            );
            session.preload(seed);
        }
        request.attach_cache_session(session);
    }
    let mirror_session = crate::solver::persistent_warm_start::open_outer_session(&seed_key);
    if let Some(mirror) = mirror_session.as_ref() {
        request.attach_cache_mirror(Arc::clone(mirror));
    }
    let wrap_solver_err =
        |reason: String| -> WorkflowError { WorkflowError::IntegrationFailed { reason } };
    match request {
        FitRequest::Standard(request) => fit_standard_model(request)
            .map(FitResult::Standard)
            .map_err(wrap_solver_err),
        FitRequest::GaussianLocationScale(request) => fit_gaussian_location_scale_model(request)
            .map(FitResult::GaussianLocationScale)
            .map_err(wrap_solver_err),
        FitRequest::BinomialLocationScale(request) => fit_binomial_location_scale_model(request)
            .map(FitResult::BinomialLocationScale)
            .map_err(wrap_solver_err),
        FitRequest::DispersionLocationScale(request) => {
            fit_dispersion_location_scale_model(request)
                .map(FitResult::DispersionLocationScale)
                .map_err(wrap_solver_err)
        }
        FitRequest::SurvivalLocationScale(request) => {
            match fit_survival_location_scale_model(request).map(FitResult::SurvivalLocationScale) {
                Ok(fit) => Ok(fit),
                // These messages mean the inner solver ended with an empty
                // block state. Replace them with an actionable diagnostic.
                Err(e)
                    if e.contains("expects 3 blocks, got 0")
                        || e.contains("expects 4 blocks, got 0")
                        || (e.contains("block_states") && e.contains("got 0"))
                        || e.contains("blockwise fit requires at least one block state") =>
                {
                    Err(WorkflowError::IntegrationFailed {
                        reason: format!(
                            "survival location-scale fit failed: the smoothing-parameter optimizer \
                             landed at a degenerate iterate where the inner solver's block state \
                             was empty. This is the symptom of an under-identified smooth driven \
                             to a numerically pathological λ (e.g. exp(20+)) on a small-data \
                             subsample. Try: (1) reducing covariate count, (2) increasing n_train, \
                             (3) `baseline_target=\"linear\"` to drop the parametric baseline, or \
                             (4) `noise_formula=\"1\"` to drop the noise GAM. Underlying error: {e}"
                        ),
                    })
                }
                Err(reason) => Err(wrap_solver_err(reason)),
            }
        }
        FitRequest::SurvivalTransformation(request) => fit_survival_transformation_model(request)
            .map(FitResult::SurvivalTransformation)
            .map_err(wrap_solver_err),
        FitRequest::BernoulliMarginalSlope(request) => fit_bernoulli_marginal_slope_model(request)
            .map(FitResult::BernoulliMarginalSlope)
            .map_err(wrap_solver_err),
        FitRequest::SurvivalMarginalSlope(request) => fit_survival_marginal_slope_model(request)
            .map(FitResult::SurvivalMarginalSlope)
            .map_err(wrap_solver_err),
        FitRequest::LatentSurvival(request) => fit_latent_survival_model(request)
            .map(FitResult::LatentSurvival)
            .map_err(wrap_solver_err),
        FitRequest::LatentBinary(request) => fit_latent_binary_model(request)
            .map(FitResult::LatentBinary)
            .map_err(wrap_solver_err),
        FitRequest::TransformationNormal(request) => fit_transformation_normal_model(request)
            .map(FitResult::TransformationNormal)
            .map_err(wrap_solver_err),
    }
}
use crate::families::survival_construction::{
SurvivalBaselineTarget, SurvivalLikelihoodMode, SurvivalTimeBasisConfig,
add_survival_time_derivative_guard_offset, append_zero_tail_columns,
baseline_chain_rule_gradient, build_latent_survival_baseline_offsets,
build_survival_time_basis, build_survival_time_offsets_for_likelihood,
build_survival_timewiggle_from_baseline, build_time_varying_survival_covariate_template,
center_survival_time_designs_at_anchor, evaluate_survival_time_basis_row,
initial_survival_baseline_config_for_fit, location_scale_uses_probit_survival_baseline,
marginal_slope_baseline_chain_rule_gradient, marginal_slope_baseline_chain_rule_hessian,
normalize_survival_time_pair, optimize_survival_baseline_config_with_gradient,
optimize_survival_baseline_config_with_gradient_only, parse_survival_distribution,
parse_survival_likelihood_mode, parse_survival_time_basis_config, positive_survival_time_seed,
require_structural_survival_time_basis, resolve_survival_marginal_slope_time_anchor_value,
resolve_survival_time_anchor_value, resolved_survival_time_basis_config_from_build,
survival_derivative_guard_for_likelihood,
};
use crate::families::survival_location_scale::{
SURVIVAL_LOCATION_SCALE_EMPTY_BLOCK_STATES_MARKER, SurvivalCovariateTermBlockTemplate,
TimeBlockInput, TimeWiggleBlockInput, residual_distribution_inverse_link,
};
use crate::inference::data::EncodedDataset as Dataset;
use crate::inference::formula_dsl::{
LinkChoice, LinkWiggleFormulaSpec, ParsedFormula, ParsedTerm, effectivelinkwiggle_formulaspec,
marginal_slope_logslope_surfaces, parse_formula, parse_link_choice,
parse_matching_auxiliary_formula, parse_surv_interval_response, parse_surv_response,
require_inverse_link_supports_joint_wiggle, validate_marginal_slope_z_column_exclusion,
};
use crate::term_builder::{
SECONDARY_CENTER_CAP_OPTION, build_termspec, column_map_with_alias, enable_scale_dimensions,
has_explicit_countwith_basis_alias, resolve_role_col, resolve_smooth_type_name,
smooth_type_uses_spatial_center_heuristic,
};
/// User-facing configuration consumed by formula-driven fitting (`fit_from_formula`,
/// `materialize`, and the fast-path helpers).
///
/// Defaults are given by the `Default` impl for this type. String-valued options (family,
/// link, baseline target, time basis, survival likelihood/distribution) are parsed by
/// downstream materializers; their accepted spellings are not visible in this section.
#[derive(Clone, Debug)]
pub struct FitConfig {
/// Response family name, or `None` when unspecified. `"expectile"` or `"expectile(τ)"`
/// (ASCII case-insensitive, surrounding whitespace ignored) routes `fit_from_formula` to the
/// expectile (LAWS) fit instead of the regular materialization path.
pub family: Option<String>,
/// NOTE(review): presumably a fixed negative-binomial dispersion θ — confirm in the family
/// materializer.
pub negative_binomial_theta: Option<f64>,
/// Link function name, or `None` when unspecified.
pub link: Option<String>,
/// Request a flexible link (default `false`). NOTE(review): appears to be what yields the
/// link "wiggle" that expectile fitting rejects — confirm.
pub flexible_link: bool,
/// Name of an offset column for the primary predictor (assumed from the name — confirm).
pub offset_column: Option<String>,
/// Name of an offset column for the noise/scale predictor (assumed from the name — confirm).
pub noise_offset_column: Option<String>,
/// Optional frailty specification (survival models; assumption — confirm).
pub frailty: Option<FrailtySpec>,
// --- Survival baseline and time-basis options. Defaults: target "linear", basis "ispline",
// --- degree 3, 8 internal knots, time smoothing λ = 1e-2.
pub baseline_target: String,
pub baseline_scale: Option<f64>,
pub baseline_shape: Option<f64>,
pub baseline_rate: Option<f64>,
pub baseline_makeham: Option<f64>,
pub time_basis: String,
pub time_degree: usize,
pub time_num_internal_knots: usize,
pub time_smooth_lambda: f64,
/// Survival likelihood mode name; defaults to `"location-scale"`.
pub survival_likelihood: String,
/// Survival residual distribution name; defaults to `"gaussian"`.
pub survival_distribution: String,
// NOTE(review): the threshold/sigma time-basis sizes (`*_k`, `None` by default) and degrees
// (default 3) look like the survival location-scale time blocks — confirm.
pub threshold_time_k: Option<usize>,
pub threshold_time_degree: usize,
pub sigma_time_k: Option<usize>,
pub sigma_time_degree: usize,
/// Formula for the noise/scale predictor. For a non-survival response (without
/// `transformation_normal`, `logslope_formula`, or `z_column`) it selects location-scale
/// materialization; it is rejected in combination with `transformation_normal`.
pub noise_formula: Option<String>,
/// Log-slope formula. When this or `z_column` is set (non-survival response, no
/// `transformation_normal`), materialization goes to the Bernoulli marginal-slope path and
/// resource hints mark marginal-slope as active.
pub logslope_formula: Option<String>,
/// Marginal-slope `z` column; same routing effect as `logslope_formula`.
pub z_column: Option<String>,
/// Name of a row-weight column (assumed from the name — confirm).
pub weight_column: Option<String>,
/// Expectile asymmetry τ, only consulted for expectile families. If an inline
/// `expectile(τ)` is also given the two must agree; τ defaults to 0.5 when neither is set.
pub expectile_tau: Option<f64>,
/// NOTE(review): stage-1 recipe for a CTN model — semantics not visible here.
pub ctn_stage1: Option<CtnStage1Recipe>,
/// Defaults to `false`; semantics not visible in this section.
pub scale_dimensions: bool,
/// Defaults to `None`; semantics not visible in this section.
pub adaptive_regularization: Option<bool>,
/// Ridge strength; defaults to `1e-6`.
pub ridge_lambda: f64,
/// Fit a transformation-normal model (default `false`). Rejected with `Surv(...)` /
/// `SurvInterval(...)` responses and with `noise_formula`.
pub transformation_normal: bool,
/// Defaults to `false`; presumably requests Firth bias reduction (assumed from the name).
pub firth: bool,
/// Optional cap on outer iterations (assumed from the name — confirm).
pub outer_max_iter: Option<usize>,
/// Installed as the global GPU policy at the start of `materialize`; defaults to `Auto`.
pub gpu_policy: crate::gpu::GpuPolicy,
/// Explicit resource policy. When `None`, `resolved_resource_policy` derives one from the
/// dataset row count and problem hints.
pub resource_policy: Option<crate::resource::ResourcePolicy>,
// NOTE(review): the remaining fields are passed through to downstream materializers; their
// semantics are not visible in this section. All default to empty/`None`/`false`.
pub group_metadata: Option<BTreeMap<String, JsonValue>>,
pub coefficient_groups: Vec<CoefficientGroupSpec>,
pub penalty_block_gamma_priors: Vec<(String, f64, f64)>,
pub latents: Option<JsonValue>,
pub analytic_penalties: Option<JsonValue>,
pub topology_auto_selector: Option<crate::solver::topology_selector::TopologyAutoSelector>,
pub smooth_overrides: Option<JsonValue>,
pub persist_warm_start_disk: bool,
}
impl Default for FitConfig {
    /// Baseline configuration.
    ///
    /// - No family, link, offsets, frailty, or auxiliary formulas are set.
    /// - Survival fits use a `"linear"` baseline target and an `"ispline"` time basis of
    ///   degree 3 with 8 internal knots and smoothing λ = 1e-2.
    /// - The survival likelihood is `"location-scale"` with a `"gaussian"` distribution.
    /// - The ridge is `1e-6` and the GPU policy is `Auto`.
    /// - Every optional override is left empty.
    fn default() -> Self {
        Self {
            // Response family / link.
            family: None,
            negative_binomial_theta: None,
            link: None,
            flexible_link: false,
            offset_column: None,
            noise_offset_column: None,
            frailty: None,
            // Survival baseline.
            baseline_target: String::from("linear"),
            baseline_scale: None,
            baseline_shape: None,
            baseline_rate: None,
            baseline_makeham: None,
            // Survival time basis.
            time_basis: String::from("ispline"),
            time_degree: 3,
            time_num_internal_knots: 8,
            time_smooth_lambda: 0.01,
            survival_likelihood: String::from("location-scale"),
            survival_distribution: String::from("gaussian"),
            threshold_time_k: None,
            threshold_time_degree: 3,
            sigma_time_k: None,
            sigma_time_degree: 3,
            // Auxiliary formulas and columns.
            noise_formula: None,
            logslope_formula: None,
            z_column: None,
            weight_column: None,
            expectile_tau: None,
            ctn_stage1: None,
            // Solver knobs.
            scale_dimensions: false,
            adaptive_regularization: None,
            ridge_lambda: 0.000_001,
            transformation_normal: false,
            firth: false,
            outer_max_iter: None,
            gpu_policy: crate::gpu::GpuPolicy::Auto,
            resource_policy: None,
            // Pass-through overrides.
            group_metadata: None,
            coefficient_groups: vec![],
            penalty_block_gamma_priors: vec![],
            latents: None,
            analytic_penalties: None,
            topology_auto_selector: None,
            smooth_overrides: None,
            persist_warm_start_disk: false,
        }
    }
}
/// Returns the resource policy to use for a fit.
///
/// An explicit `config.resource_policy` wins. Otherwise a policy is derived from the dataset
/// row count and the supplied problem `hints`.
pub(crate) fn resolved_resource_policy(
    config: &FitConfig,
    data: &Dataset,
    hints: crate::resource::ProblemHints,
) -> crate::resource::ResourcePolicy {
    match &config.resource_policy {
        Some(explicit) => explicit.clone(),
        None => crate::resource::ResourcePolicy::for_problem(data.values.nrows(), 0, hints),
    }
}
/// Builds resource-planning hints for a config.
///
/// The marginal-slope flag is raised whenever a `z_column` or `logslope_formula` is
/// configured.
fn marginal_slope_hints(config: &FitConfig) -> crate::resource::ProblemHints {
    let marginal_slope_requested = config.z_column.is_some() || config.logslope_formula.is_some();
    crate::resource::ProblemHints {
        marginal_slope_large_scale_active: marginal_slope_requested,
    }
}
/// Result of [`materialize`]: a fit request ready to hand to the solver dispatch.
pub struct MaterializedModel<'a> {
/// The fit request. It borrows from the source dataset for lifetime `'a`.
pub request: FitRequest<'a>,
/// Notes produced while materializing. NOTE(review): presumably human-readable remarks about
/// inferred settings — confirm with the individual materializers.
pub inference_notes: Vec<String>,
}
/// Resolves the expectile asymmetry τ for a config.
///
/// Returns `Ok(None)` when `config.family` is absent or is not an expectile family. A family
/// matches when, after trimming and ASCII lowercasing, it is exactly `"expectile"` or starts
/// with `"expectile("`.
///
/// τ may be written inline as `expectile(τ)`, supplied via `config.expectile_tau`, or both,
/// in which case the two must agree. It defaults to 0.5 when neither is given.
///
/// # Errors
/// `WorkflowError::InvalidConfig` in any of these cases:
/// - the inline form lacks its closing `)`;
/// - the inline value does not parse as `f64`;
/// - the inline and configured values disagree;
/// - the resolved τ is not finite and strictly inside (0, 1).
fn expectile_tau_for_config(config: &FitConfig) -> Result<Option<f64>, WorkflowError> {
    let Some(family) = config.family.as_deref() else {
        return Ok(None);
    };
    let written = family.trim();
    let normalized = written.to_ascii_lowercase();
    let is_expectile = normalized == "expectile" || normalized.starts_with("expectile(");
    if !is_expectile {
        return Ok(None);
    }
    let reject = |reason: String| WorkflowError::InvalidConfig { reason };

    let inline_tau = match normalized.strip_prefix("expectile(") {
        None => None,
        Some(after_paren) => {
            let Some(body) = after_paren.strip_suffix(')') else {
                return Err(reject(format!(
                    "expectile family asymmetry must be written as `expectile(τ)`; got `{written}`"
                )));
            };
            let body = body.trim();
            let parsed = body.parse::<f64>().map_err(|_| {
                reject(format!(
                    "expectile asymmetry `{body}` is not a finite number"
                ))
            })?;
            Some(parsed)
        }
    };

    let tau = match (inline_tau, config.expectile_tau) {
        // `abs() > 0.0` (rather than `!=`) lets a NaN pair fall through to the range check.
        (Some(inline), Some(explicit)) if (inline - explicit).abs() > 0.0 => {
            return Err(reject(format!(
                "expectile asymmetry given both inline (`expectile({inline})`) and via expectile_tau \
                 ({explicit}); supply exactly one"
            )));
        }
        (Some(inline), _) => inline,
        (None, Some(explicit)) => explicit,
        (None, None) => 0.5,
    };

    let strictly_inside_unit_interval = tau.is_finite() && tau > 0.0 && tau < 1.0;
    if strictly_inside_unit_interval {
        Ok(Some(tau))
    } else {
        Err(reject(format!(
            "expectile asymmetry τ must be finite and strictly in (0, 1); got {tau}"
        )))
    }
}
fn expectile_row_weights(
y: ArrayView1<f64>,
mu: ArrayView1<f64>,
base: ArrayView1<f64>,
tau: f64,
) -> Array1<f64> {
Array1::from_shape_fn(y.len(), |i| {
let asym = if y[i] > mu[i] { tau } else { 1.0 - tau };
base[i] * asym
})
}
/// Fits `formula` against `data` using `config`, choosing the cheapest applicable solver.
///
/// Routing:
/// 1. Expectile families go to the LAWS expectile fit.
/// 2. Otherwise the model is materialized. A standard request that qualifies for the 1-D
///    spline scan is solved exactly there, and scan errors are propagated.
/// 3. A standard request that qualifies for the residual cascade is tried next. If the
///    cascade fails, control falls through silently.
/// 4. Everything else goes to the general `fit_model` dispatch.
pub fn fit_from_formula(
    formula: &str,
    data: &Dataset,
    config: &FitConfig,
) -> Result<FitResult, WorkflowError> {
    if let Some(tau) = expectile_tau_for_config(config)? {
        return fit_expectile_laws(formula, data, config, tau);
    }
    let materialized = materialize(formula, data, config)?;
    if let FitRequest::Standard(standard) = &materialized.request {
        if let Some(scan) = spline_scan_fast_path(standard) {
            return crate::solver::spline_scan::fit_spline_scan(
                &scan.x, &scan.y, &scan.w, scan.order,
            )
            .map(FitResult::SplineScan)
            .map_err(|reason| WorkflowError::IntegrationFailed { reason });
        }
        if let Some(cascade) = residual_cascade_fast_path(standard) {
            let axes: Vec<&[f64]> = cascade.coords.iter().map(|axis| axis.as_slice()).collect();
            let attempt = crate::solver::residual_cascade::fit_residual_cascade(
                &axes,
                &cascade.y,
                &cascade.w,
                &cascade.metric,
                cascade.sobolev_s,
            );
            // Best effort: a cascade failure falls back to the general solver below.
            if let Ok(fit) = attempt {
                return Ok(FitResult::ResidualCascade(fit));
            }
        }
    }
    fit_model(materialized.request)
}
/// Fits an expectile regression at asymmetry `tau` by iteratively reweighted Gaussian fits
/// (asymmetric least squares, "LAWS").
///
/// The formula is materialized as a Gaussian/identity standard request. The first fit uses
/// the base row weights. Each later fit uses `expectile_row_weights`, driven by the previous
/// fit's linear predictor (design × β plus offset).
///
/// Iteration stops when the above/below residual sign pattern stops changing. At that point
/// the last fit was computed with exactly the weights its own residuals imply (a LAWS fixed
/// point), and that fit is returned.
///
/// # Errors
/// - `InvalidConfig` if the formula does not materialize to a standard request, if the
///   materialized family is not Gaussian-identity, or if it carries a link wiggle or latent
///   coordinates.
/// - `IntegrationFailed` if an inner fit fails or the fitted mean has the wrong length.
/// - `IntegrationFailed` if the sign pattern is still changing after `MAX_LAWS_ITERS` refits.
///   In that case the last iterate is not an expectile estimate and is not returned.
fn fit_expectile_laws(
    formula: &str,
    data: &Dataset,
    config: &FitConfig,
    tau: f64,
) -> Result<FitResult, WorkflowError> {
    use crate::linalg::matrix::LinearOperator;
    const MAX_LAWS_ITERS: usize = 50;

    let gaussian_config = FitConfig {
        family: Some("gaussian".to_string()),
        link: Some("identity".to_string()),
        expectile_tau: None,
        ..config.clone()
    };
    let base_mat = materialize(formula, data, &gaussian_config)?;
    let FitRequest::Standard(base_request) = base_mat.request else {
        return Err(WorkflowError::InvalidConfig {
            reason: "expectile regression is only defined for standard (non-survival, \
                     non-location-scale) responses"
                .to_string(),
        });
    };
    let StandardFitRequest {
        data: design_data,
        y,
        weights: base_weights,
        offset,
        spec,
        family: materialized_family,
        options,
        kappa_options,
        wiggle,
        coefficient_groups,
        penalty_block_gamma_priors,
        latent_coord,
        _marker,
    } = base_request;
    if !materialized_family.is_gaussian_identity() {
        return Err(WorkflowError::InvalidConfig {
            reason: format!(
                "expectile LAWS requires a Gaussian-identity inner family; materializer produced {}",
                materialized_family.name()
            ),
        });
    }
    if wiggle.is_some() || latent_coord.is_some() {
        return Err(WorkflowError::InvalidConfig {
            reason: "expectile regression does not support flexible-link wiggle or latent \
                     coordinates"
                .to_string(),
        });
    }

    let n = y.len();
    let gaussian_family = LikelihoodSpec::gaussian_identity();
    // The first pass uses the symmetric base weights. Later passes use the asymmetric
    // weights implied by the previous pass's residual signs.
    let mut weights = base_weights.clone();
    let mut previous_sign: Option<Vec<bool>> = None;
    for _ in 0..MAX_LAWS_ITERS {
        let request = StandardFitRequest {
            data: design_data.clone(),
            y: y.clone(),
            weights: weights.clone(),
            offset: offset.clone(),
            spec: spec.clone(),
            family: gaussian_family.clone(),
            options: options.clone(),
            kappa_options: kappa_options.clone(),
            wiggle: None,
            coefficient_groups: coefficient_groups.clone(),
            penalty_block_gamma_priors: penalty_block_gamma_priors.clone(),
            latent_coord: None,
            _marker,
        };
        let result = fit_standard_model(request)
            .map_err(|reason| WorkflowError::IntegrationFailed { reason })?;
        let mut eta = result.design.design.apply(&result.fit.beta);
        if eta.len() != n {
            return Err(WorkflowError::IntegrationFailed {
                reason: format!(
                    "expectile LAWS: fitted mean length {} disagrees with response length {n}",
                    eta.len()
                ),
            });
        }
        // The design product excludes the offset; residual signs must use the full predictor.
        eta += &offset;
        let sign: Vec<bool> = (0..n).map(|i| y[i] > eta[i]).collect();
        // An unchanged sign pattern means this fit's weights are the ones its own residuals
        // imply, so it is a LAWS fixed point.
        if previous_sign.as_ref() == Some(&sign) {
            return Ok(FitResult::Standard(result));
        }
        weights = expectile_row_weights(y.view(), eta.view(), base_weights.view(), tau);
        previous_sign = Some(sign);
    }
    Err(WorkflowError::IntegrationFailed {
        reason: format!(
            "expectile LAWS did not converge within {MAX_LAWS_ITERS} iterations at τ = {tau}: \
             the above/below-fit residual sign pattern was still changing"
        ),
    })
}
/// Row-aligned inputs for the exact 1-D spline-scan solver, produced by
/// `spline_scan_fast_path`.
pub struct SplineScanInputs {
/// Values of the single smooth's covariate column, one per row (all finite).
pub x: Vec<f64>,
/// Response values, one per row (all finite).
pub y: Vec<f64>,
/// Row weights (all finite and strictly positive).
pub w: Vec<f64>,
/// Penalty order of the B-spline smooth, in `1..=3`. The basis degree is `2 * order - 1`.
pub order: usize,
}
/// Decides whether a standard request can be answered by the exact 1-D spline-scan solver,
/// and if so extracts its row-aligned inputs.
///
/// The model must be a Gaussian/identity fit with exactly one smooth term and nothing else:
/// - no linear or random-effect terms;
/// - no link wiggle, latent coordinates, coefficient groups, or penalty-block gamma priors;
/// - no alternative links, linear constraints, adaptive regularization, Kronecker penalties,
///   Firth correction, or nullspace dimensions;
/// - no shape constraint or joint null rotation on the smooth.
///
/// The smooth must be a `BSpline1D` basis with:
/// - penalty order `m` in `1..=3` and degree `2m - 1`;
/// - no double penalty, free boundary conditions, and an open (non-periodic) boundary and
///   knot layout.
///
/// The data must have:
/// - a zero offset;
/// - finite, strictly positive weights;
/// - finite `x` and `y`;
/// - matching row counts for `x`, `y`, and weights;
/// - at least `m + 1` distinct covariate values.
///
/// Returns `None` whenever any condition fails, so callers can fall back to the general solver.
pub fn spline_scan_fast_path(request: &StandardFitRequest<'_>) -> Option<SplineScanInputs> {
    if !request.family.is_gaussian_identity() {
        return None;
    }
    if request.wiggle.is_some()
        || request.latent_coord.is_some()
        || !request.coefficient_groups.is_empty()
        || !request.penalty_block_gamma_priors.is_empty()
    {
        return None;
    }
    let options = &request.options;
    if options.latent_cloglog.is_some()
        || options.mixture_link.is_some()
        || options.sas_link.is_some()
        || options.linear_constraints.is_some()
        || options.adaptive_regularization.is_some()
        || options.kronecker_penalty_system.is_some()
        || options.kronecker_factored.is_some()
        || options.firth_bias_reduction
        || !options.nullspace_dims.is_empty()
    {
        return None;
    }
    let spec = &request.spec;
    if !spec.linear_terms.is_empty()
        || !spec.random_effect_terms.is_empty()
        || spec.smooth_terms.len() != 1
    {
        return None;
    }
    let term = &spec.smooth_terms[0];
    if !matches!(term.shape, crate::smooth::ShapeConstraint::None)
        || term.joint_null_rotation.is_some()
    {
        return None;
    }
    let crate::smooth::SmoothBasisSpec::BSpline1D {
        feature_col,
        spec: bspec,
    } = &term.basis
    else {
        return None;
    };
    let order = bspec.penalty_order;
    // The range check must stay first: it short-circuits before `2 * order - 1` could
    // underflow for `order == 0`.
    if !(1..=3).contains(&order)
        || bspec.degree != 2 * order - 1
        || bspec.double_penalty
        || !bspec.boundary_conditions.is_free()
        || !matches!(bspec.boundary, crate::basis::OneDimensionalBoundary::Open)
        || matches!(
            bspec.knotspec,
            crate::basis::BSplineKnotSpec::PeriodicUniform { .. }
        )
    {
        return None;
    }
    if request.offset.iter().any(|&v| v != 0.0) {
        return None;
    }
    if request.weights.iter().any(|&v| !(v.is_finite() && v > 0.0)) {
        return None;
    }
    let n = request.y.len();
    if *feature_col >= request.data.ncols()
        || n != request.data.nrows()
        // The scan solver consumes `w` row-for-row with `x` and `y`. A length mismatch must
        // fall back to the general path instead of reaching the solver.
        || request.weights.len() != n
    {
        return None;
    }
    let x: Vec<f64> = request.data.column(*feature_col).iter().copied().collect();
    let y: Vec<f64> = request.y.iter().copied().collect();
    let w: Vec<f64> = request.weights.iter().copied().collect();
    if x.iter().any(|v| !v.is_finite()) || y.iter().any(|v| !v.is_finite()) {
        return None;
    }
    // The penalty needs at least `order + 1` distinct abscissae to be well posed.
    let mut sorted = x.clone();
    sorted.sort_by(f64::total_cmp);
    sorted.dedup();
    if sorted.len() < order + 1 {
        return None;
    }
    Some(SplineScanInputs { x, y, w, order })
}
/// Materializes `formula` and, if the request qualifies for the spline-scan fast path,
/// solves it with the exact scan solver.
///
/// Returns `Ok(None)` when the request is not a standard request or does not qualify.
///
/// # Errors
/// Materialization errors, and scan-solver failures wrapped as `IntegrationFailed`.
pub fn fit_spline_scan_from_formula(
    formula: &str,
    data: &Dataset,
    config: &FitConfig,
) -> Result<Option<crate::solver::spline_scan::SplineScanFit>, WorkflowError> {
    let materialized = materialize(formula, data, config)?;
    let request = match materialized.request {
        FitRequest::Standard(standard) => standard,
        _ => return Ok(None),
    };
    match spline_scan_fast_path(&request) {
        None => Ok(None),
        Some(scan) => {
            crate::solver::spline_scan::fit_spline_scan(&scan.x, &scan.y, &scan.w, scan.order)
                .map(Some)
                .map_err(|reason| WorkflowError::IntegrationFailed { reason })
        }
    }
}
/// Inputs for the residual-cascade solver, produced by `residual_cascade_fast_path`.
pub struct ResidualCascadeInputs {
/// Coordinates stored axis-major: one inner `Vec` per feature column (2 or 3 axes), each
/// holding one finite value per row.
pub coords: Vec<Vec<f64>>,
/// Response values, one per row (all finite).
pub y: Vec<f64>,
/// Row weights (all finite and strictly positive).
pub w: Vec<f64>,
/// Per-axis metric scaling. Currently all ones, one entry per axis.
pub metric: Vec<f64>,
/// Sobolev order after clamping by `cascade_sobolev_order`.
pub sobolev_s: f64,
}
/// Returns `true` once the default kernel-center count for `n` rows in `d` dimensions
/// reaches 2000.
///
/// The residual-cascade fast path only engages past this point.
fn past_dense_kernel_cliff(n: usize, d: usize) -> bool {
    const DENSE_CENTER_CAP: usize = 2000;
    let default_centers = crate::terms::basis::default_num_centers(n, d);
    DENSE_CENTER_CAP <= default_centers
}
/// Clamps a requested Sobolev order into `(d/2, (d+3)/2]` for a `d`-dimensional cascade.
///
/// The lower end is raised by a relative margin of `1e-6` of the interval width, so the
/// result is strictly greater than `d/2`. NOTE(review): presumably because the cascade needs
/// `s > d/2`, the Sobolev embedding threshold. A NaN `requested` passes through as NaN.
fn cascade_sobolev_order(requested: f64, d: usize) -> f64 {
    let dim = d as f64;
    let floor = 0.5 * dim;
    let ceiling = 0.5 * (dim + 3.0);
    let margin = (ceiling - floor) * 1e-6;
    requested.clamp(floor + margin, ceiling)
}
/// Decides whether a standard request should be solved by the residual-cascade solver, and
/// if so extracts its inputs.
///
/// The model must be a Gaussian/identity fit with exactly one smooth term and nothing else:
/// - no linear or random-effect terms;
/// - no link wiggle, latent coordinates, coefficient groups, or penalty-block gamma priors;
/// - no alternative links, linear constraints, adaptive regularization, Kronecker penalties,
///   Firth correction, or nullspace dimensions;
/// - no shape constraint or joint null rotation on the smooth.
///
/// The smooth must be a Duchon or Matérn basis over 2 or 3 feature columns. The requested
/// Sobolev order is derived from the basis:
/// - Duchon: `power + nullspace order`;
/// - Matérn: `ν + d/2`.
///
/// It is then clamped by `cascade_sobolev_order`.
///
/// The data must have:
/// - a zero offset;
/// - finite, strictly positive weights;
/// - finite coordinates and responses;
/// - matching row counts for the data, `y`, and weights;
/// - a problem size past the dense-kernel cliff (see `past_dense_kernel_cliff`).
///
/// Returns `None` whenever any condition fails, so callers can fall back to the general solver.
pub fn residual_cascade_fast_path(
    request: &StandardFitRequest<'_>,
) -> Option<ResidualCascadeInputs> {
    if !request.family.is_gaussian_identity() {
        return None;
    }
    if request.wiggle.is_some()
        || request.latent_coord.is_some()
        || !request.coefficient_groups.is_empty()
        || !request.penalty_block_gamma_priors.is_empty()
    {
        return None;
    }
    let options = &request.options;
    if options.latent_cloglog.is_some()
        || options.mixture_link.is_some()
        || options.sas_link.is_some()
        || options.linear_constraints.is_some()
        || options.adaptive_regularization.is_some()
        || options.kronecker_penalty_system.is_some()
        || options.kronecker_factored.is_some()
        || options.firth_bias_reduction
        || !options.nullspace_dims.is_empty()
    {
        return None;
    }
    let spec = &request.spec;
    if !spec.linear_terms.is_empty()
        || !spec.random_effect_terms.is_empty()
        || spec.smooth_terms.len() != 1
    {
        return None;
    }
    let term = &spec.smooth_terms[0];
    if !matches!(term.shape, crate::smooth::ShapeConstraint::None)
        || term.joint_null_rotation.is_some()
    {
        return None;
    }
    let (feature_cols, requested_s) = match &term.basis {
        crate::smooth::SmoothBasisSpec::Duchon {
            feature_cols, spec, ..
        } => {
            let p = match spec.nullspace_order {
                crate::basis::DuchonNullspaceOrder::Zero => 0.0,
                crate::basis::DuchonNullspaceOrder::Linear => 1.0,
                crate::basis::DuchonNullspaceOrder::Degree(k) => k as f64,
            };
            (feature_cols, spec.power + p)
        }
        crate::smooth::SmoothBasisSpec::Matern {
            feature_cols, spec, ..
        } => {
            let nu = spec.nu.half_integer_value();
            (feature_cols, nu + feature_cols.len() as f64 / 2.0)
        }
        _ => return None,
    };
    let d = feature_cols.len();
    if !(2..=3).contains(&d) {
        return None;
    }
    if request.offset.iter().any(|&v| v != 0.0) {
        return None;
    }
    if request.weights.iter().any(|&v| !(v.is_finite() && v > 0.0)) {
        return None;
    }
    let n = request.y.len();
    if n != request.data.nrows()
        // The cascade consumes `w` row-for-row with the coordinates and `y`. A length mismatch
        // must fall back to the general path instead of reaching the solver.
        || request.weights.len() != n
        || feature_cols.iter().any(|&c| c >= request.data.ncols())
    {
        return None;
    }
    // Below the cliff a dense kernel basis is affordable, so the general path is kept.
    if !past_dense_kernel_cliff(n, d) {
        return None;
    }
    let coords: Vec<Vec<f64>> = feature_cols
        .iter()
        .map(|&c| request.data.column(c).iter().copied().collect())
        .collect();
    let y: Vec<f64> = request.y.iter().copied().collect();
    let w: Vec<f64> = request.weights.iter().copied().collect();
    if coords
        .iter()
        .any(|axis| axis.iter().any(|v| !v.is_finite()))
        || y.iter().any(|v| !v.is_finite())
    {
        return None;
    }
    let metric = vec![1.0_f64; d];
    let sobolev_s = cascade_sobolev_order(requested_s, d);
    Some(ResidualCascadeInputs {
        coords,
        y,
        w,
        metric,
        sobolev_s,
    })
}
/// Materializes `formula` and, if the request qualifies, solves it with the residual cascade.
///
/// Returns `Ok(None)` in any of these cases:
/// - the request is not a standard request;
/// - the request does not qualify for the fast path;
/// - the cascade solver itself fails. This is deliberate: a failure is treated as "fast path
///   not applicable", not as an error.
///
/// # Errors
/// Only materialization errors are propagated.
pub fn fit_residual_cascade_from_formula(
    formula: &str,
    data: &Dataset,
    config: &FitConfig,
) -> Result<Option<crate::solver::residual_cascade::ResidualCascadeFit>, WorkflowError> {
    let materialized = materialize(formula, data, config)?;
    let request = match materialized.request {
        FitRequest::Standard(standard) => standard,
        _ => return Ok(None),
    };
    let inputs = match residual_cascade_fast_path(&request) {
        Some(inputs) => inputs,
        None => return Ok(None),
    };
    let axis_slices: Vec<&[f64]> = inputs.coords.iter().map(|axis| axis.as_slice()).collect();
    let outcome = crate::solver::residual_cascade::fit_residual_cascade(
        &axis_slices,
        &inputs.y,
        &inputs.w,
        &inputs.metric,
        inputs.sobolev_s,
    );
    Ok(outcome.ok())
}
/// Parses `formula` and builds the fit request that matches the response type and `config`.
///
/// Dispatch order:
/// 1. A `SurvInterval(...)` response goes to survival materialization with a right-censoring
///    column.
/// 2. A `Surv(...)` response goes to survival materialization, with an optional entry column.
/// 3. Any other response first has survival-only terms rejected, then is routed to:
///    - `transformation_normal` → transformation-normal;
///    - `logslope_formula` or `z_column` set → Bernoulli marginal-slope;
///    - `noise_formula` set → location-scale;
///    - otherwise → the standard model.
///
/// Side effect: installs `config.gpu_policy` as the global GPU policy before parsing.
///
/// # Errors
/// - Formula parse errors.
/// - `InvalidConfig` when `transformation_normal` is combined with a survival response or with
///   `noise_formula`.
/// - Any error returned by the selected materializer.
pub fn materialize<'a>(
    formula: &str,
    data: &'a Dataset,
    config: &FitConfig,
) -> Result<MaterializedModel<'a>, WorkflowError> {
    crate::gpu::configure_global_policy(config.gpu_policy);
    let parsed = parse_formula(formula)?;
    let col_map = data.column_map();

    if let Some((left_col, right_col, event_col)) = parse_surv_interval_response(&parsed.response)? {
        if config.transformation_normal {
            return Err(WorkflowError::InvalidConfig {
                reason: "transformation_normal cannot be combined with a SurvInterval(...) response"
                    .to_string(),
            });
        }
        return materialize_survival(
            &parsed,
            data,
            &col_map,
            config,
            None,
            &left_col,
            &event_col,
            Some(&right_col),
        );
    }

    if let Some((entry_col, exit_col, event_col)) = parse_surv_response(&parsed.response)? {
        if config.transformation_normal {
            return Err(WorkflowError::InvalidConfig {
                reason: "transformation_normal cannot be combined with a Surv(...) response"
                    .to_string(),
            });
        }
        return materialize_survival(
            &parsed,
            data,
            &col_map,
            config,
            entry_col.as_deref(),
            &exit_col,
            &event_col,
            None,
        );
    }

    reject_survival_only_terms_for_nonsurvival(&parsed)?;

    if config.transformation_normal {
        reject_marginal_slope_controls_for_transformation_normal(config)?;
        if config.noise_formula.is_some() {
            return Err(WorkflowError::InvalidConfig {
                reason: "transformation_normal cannot be combined with noise_formula"
                    .to_string(),
            });
        }
        return materialize_transformation_normal(&parsed, data, &col_map, config);
    }
    if config.logslope_formula.is_some() || config.z_column.is_some() {
        return materialize_bernoulli_marginal_slope(&parsed, data, &col_map, config);
    }
    if config.noise_formula.is_some() {
        return materialize_location_scale(&parsed, data, &col_map, config);
    }
    materialize_standard(&parsed, data, &col_map, config)
}