use super::*;
/// Settings for an optional spline "wiggle" term.
///
/// It is carried as `Option<LinkWiggleConfig>` by the Gaussian, binomial and
/// survival location-scale fit requests, and inside
/// [`StandardBinomialWiggleConfig`] for the standard path.
#[derive(Clone, Debug)]
pub struct LinkWiggleConfig {
/// Degree of the wiggle basis (assumed to be the spline polynomial degree — confirm).
pub degree: usize,
/// Number of internal knots of the wiggle basis.
pub num_internal_knots: usize,
/// Penalty orders to build; presumably one penalty per entry — confirm with the consumer.
pub penalty_orders: Vec<usize>,
/// Whether an additional ("double") penalty is requested; exact meaning is
/// defined by the code that consumes this config — TODO confirm.
pub double_penalty: bool,
}
/// Wiggle settings for the standard binomial path.
///
/// Bundles the inverse link (`link_kind`), the wiggle basis settings, and a
/// separate set of blockwise options named `refit_options` (the name suggests
/// a refit after the base fit — confirm against callers).
#[derive(Clone)]
pub struct StandardBinomialWiggleConfig {
pub link_kind: InverseLink,
pub wiggle: LinkWiggleConfig,
pub refit_options: BlockwiseFitOptions,
}
/// Request for the `FitRequest::Standard` path.
///
/// Unlike the other `*FitRequest` types in this module, which borrow their
/// data as an `ArrayView2<'a, f64>`, this request owns its data matrix,
/// response, weights and offset.
pub struct StandardFitRequest<'a> {
pub data: Array2<f64>,
pub y: Array1<f64>,
pub weights: Array1<f64>,
pub offset: Array1<f64>,
pub spec: TermCollectionSpec,
pub family: LikelihoodSpec,
pub options: FitOptions,
pub kappa_options: SpatialLengthScaleOptimizationOptions,
/// Optional binomial link-wiggle settings.
pub wiggle: Option<StandardBinomialWiggleConfig>,
pub coefficient_groups: Vec<CoefficientGroupSpec>,
/// Gamma priors keyed by penalty-block name, stored as `(String, f64, f64)`
/// (presumably `(block name, shape, rate)` — confirm against the consumer).
pub penalty_block_gamma_priors: Vec<(String, f64, f64)>,
pub latent_coord: Option<StandardLatentCoordConfig>,
// Every other field is owned, so this marker is the only use of `'a`. An
// unused lifetime parameter is a compile error (E0392), and `FitRequest<'a>`
// needs every variant's payload to accept `'a`.
#[doc(hidden)]
pub _marker: std::marker::PhantomData<&'a ()>,
}
/// Request for a Gaussian location-scale fit over a borrowed data matrix.
pub struct GaussianLocationScaleFitRequest<'a> {
pub data: ArrayView2<'a, f64>,
pub spec: GaussianLocationScaleTermSpec,
/// Optional link-wiggle settings.
pub wiggle: Option<LinkWiggleConfig>,
pub options: BlockwiseFitOptions,
pub kappa_options: SpatialLengthScaleOptimizationOptions,
}
/// Request for a binomial location-scale fit over a borrowed data matrix.
pub struct BinomialLocationScaleFitRequest<'a> {
pub data: ArrayView2<'a, f64>,
pub spec: BinomialLocationScaleTermSpec,
/// Optional link-wiggle settings.
pub wiggle: Option<LinkWiggleConfig>,
pub options: BlockwiseFitOptions,
pub kappa_options: SpatialLengthScaleOptimizationOptions,
}
/// Request for a dispersion-GLM location-scale fit over a borrowed data
/// matrix. Unlike its Gaussian/binomial siblings it has no wiggle option.
pub struct DispersionLocationScaleFitRequest<'a> {
pub data: ArrayView2<'a, f64>,
pub spec: DispersionGlmLocationScaleTermSpec,
pub options: BlockwiseFitOptions,
pub kappa_options: SpatialLengthScaleOptimizationOptions,
}
/// Request for a survival location-scale fit over a borrowed data matrix.
///
/// Unlike the other location-scale requests it carries no
/// `BlockwiseFitOptions`. It instead has a flag for optimizing the inverse
/// link and an optional shared warm-start session.
pub struct SurvivalLocationScaleFitRequest<'a> {
pub data: ArrayView2<'a, f64>,
pub spec: SurvivalLocationScaleTermSpec,
/// Optional link-wiggle settings.
pub wiggle: Option<LinkWiggleConfig>,
pub kappa_options: SpatialLengthScaleOptimizationOptions,
pub optimize_inverse_link: bool,
/// Optional `crate::warm_start::Session`, shared via `Arc`.
pub cache_session: Option<std::sync::Arc<crate::warm_start::Session>>,
}
/// Request for a survival transformation fit: a borrowed data matrix, the
/// term specification, and an optional shared warm-start session.
pub struct SurvivalTransformationFitRequest<'a> {
pub data: ArrayView2<'a, f64>,
pub spec: SurvivalTransformationTermSpec,
/// Optional `crate::warm_start::Session`, shared via `Arc`.
pub cache_session: Option<std::sync::Arc<crate::warm_start::Session>>,
}
/// Term specification for a survival transformation fit.
///
/// Per-row arrays (`age_entry`, `age_exit`, `event_target`, `weights`,
/// `covariate_offset`) are presumably all of length n (one entry per
/// observation) — confirm at the construction site.
#[derive(Clone)]
pub struct SurvivalTransformationTermSpec {
pub age_entry: Array1<f64>,
pub age_exit: Array1<f64>,
/// Event indicator stored as `u8` (presumably 0/1 — confirm).
pub event_target: Array1<u8>,
pub weights: Array1<f64>,
pub covariate_spec: TermCollectionSpec,
pub covariate_offset: Array1<f64>,
pub baseline_cfg: crate::families::survival::SurvivalBaselineConfig,
pub likelihood_mode: crate::families::survival::SurvivalLikelihoodMode,
pub time_anchor: f64,
pub time_build: crate::families::survival::SurvivalTimeBuildOutput,
/// Optional time-wiggle formula specification.
pub timewiggle: Option<LinkWiggleFormulaSpec>,
/// Optional Weibull starting pair (presumably `(scale, shape)` — confirm).
pub weibull_seed: Option<(f64, f64)>,
pub ridge_lambda: f64,
/// Gamma priors keyed by penalty-block name, stored as `(String, f64, f64)`.
pub penalty_block_gamma_priors: Vec<(String, f64, f64)>,
}
/// Request for a Bernoulli marginal-slope fit. It is the only request in this
/// module that carries an explicit `ResourcePolicy`.
pub struct BernoulliMarginalSlopeFitRequest<'a> {
pub data: ArrayView2<'a, f64>,
pub spec: BernoulliMarginalSlopeTermSpec,
pub options: BlockwiseFitOptions,
pub kappa_options: SpatialLengthScaleOptimizationOptions,
pub policy: crate::resource::ResourcePolicy,
}
/// Request for a survival marginal-slope fit over a borrowed data matrix.
pub struct SurvivalMarginalSlopeFitRequest<'a> {
pub data: ArrayView2<'a, f64>,
pub spec: SurvivalMarginalSlopeTermSpec,
pub options: BlockwiseFitOptions,
pub kappa_options: SpatialLengthScaleOptimizationOptions,
}
/// Request for a latent survival fit with a frailty specification.
/// It carries no spatial length-scale options.
pub struct LatentSurvivalFitRequest<'a> {
pub data: ArrayView2<'a, f64>,
pub spec: LatentSurvivalTermSpec,
pub frailty: FrailtySpec,
pub options: BlockwiseFitOptions,
}
/// Request for a latent binary fit with a frailty specification.
/// It carries no spatial length-scale options.
pub struct LatentBinaryFitRequest<'a> {
pub data: ArrayView2<'a, f64>,
pub spec: LatentBinaryTermSpec,
pub frailty: FrailtySpec,
pub options: BlockwiseFitOptions,
}
/// Request for a transformation-normal fit.
///
/// The covariate data matrix is borrowed. The response, weights and offset
/// are owned arrays.
pub struct TransformationNormalFitRequest<'a> {
pub data: ArrayView2<'a, f64>,
pub response: Array1<f64>,
pub weights: Array1<f64>,
pub offset: Array1<f64>,
pub covariate_spec: TermCollectionSpec,
pub config: TransformationNormalConfig,
pub options: BlockwiseFitOptions,
pub kappa_options: SpatialLengthScaleOptimizationOptions,
/// Optional warm start for the transformation fit.
pub warm_start: Option<TransformationWarmStart>,
}
/// A fit request for one model family.
///
/// Each variant wraps the correspondingly named `*FitRequest<'a>` struct from
/// this module. `'a` is the lifetime of any borrowed data matrix.
pub enum FitRequest<'a> {
Standard(StandardFitRequest<'a>),
GaussianLocationScale(GaussianLocationScaleFitRequest<'a>),
BinomialLocationScale(BinomialLocationScaleFitRequest<'a>),
DispersionLocationScale(DispersionLocationScaleFitRequest<'a>),
SurvivalLocationScale(SurvivalLocationScaleFitRequest<'a>),
SurvivalTransformation(SurvivalTransformationFitRequest<'a>),
BernoulliMarginalSlope(BernoulliMarginalSlopeFitRequest<'a>),
SurvivalMarginalSlope(SurvivalMarginalSlopeFitRequest<'a>),
LatentSurvival(LatentSurvivalFitRequest<'a>),
LatentBinary(LatentBinaryFitRequest<'a>),
TransformationNormal(TransformationNormalFitRequest<'a>),
}
/// Result of a standard fit.
///
/// Holds the unified fit, the built design, the resolved term specification,
/// optional diagnostics/timing, and the saved link state.
pub struct StandardFitResult {
pub fit: UnifiedFitResult,
pub design: TermCollectionDesign,
/// Term specification after resolution (as opposed to the requested spec).
pub resolvedspec: TermCollectionSpec,
pub adaptive_diagnostics: Option<AdaptiveRegularizationDiagnostics>,
pub kappa_timing: Option<SpatialLengthScaleOptimizationTiming>,
pub saved_link_state: FittedLinkState,
/// Wiggle knot vector and degree; presumably `Some` only when a wiggle was
/// fitted — confirm against the producer.
pub wiggle_knots: Option<Array1<f64>>,
pub wiggle_degree: Option<usize>,
}
/// Result of a survival location-scale fit, together with the inverse link
/// used and optional wiggle knots/degree.
pub struct SurvivalLocationScaleFitResult {
pub fit: SurvivalLocationScaleTermFitResult,
pub inverse_link: InverseLink,
pub wiggle_knots: Option<Array1<f64>>,
pub wiggle_degree: Option<usize>,
}
/// Result of a survival transformation fit.
///
/// It also records the baseline configuration, the likelihood mode and the
/// saved time basis needed to reconstruct the time design.
pub struct SurvivalTransformationFitResult {
pub fit: UnifiedFitResult,
pub resolvedspec: TermCollectionSpec,
pub baseline_cfg: crate::families::survival::SurvivalBaselineConfig,
pub likelihood_mode: crate::families::survival::SurvivalLikelihoodMode,
pub time_basis: crate::families::survival::SavedSurvivalTimeBasis,
/// Column count of the base time design (presumably before any time
/// wiggle columns are appended — confirm).
pub time_base_ncols: usize,
pub baseline_timewiggle: Option<TimeWiggleBlockInput>,
}
/// Outcome of a fit, one variant per model family.
///
/// Most variants mirror a `FitRequest` variant. `SplineScan` and
/// `ResidualCascade` have no counterpart in `FitRequest`; presumably they are
/// produced from `SplineScanInputs` / `ResidualCascadeInputs` — confirm.
pub enum FitResult {
Standard(StandardFitResult),
GaussianLocationScale(GaussianLocationScaleFitResult),
BinomialLocationScale(BinomialLocationScaleFitResult),
DispersionLocationScale(DispersionLocationScaleFitResult),
SurvivalLocationScale(SurvivalLocationScaleFitResult),
SurvivalTransformation(SurvivalTransformationFitResult),
BernoulliMarginalSlope(BernoulliMarginalSlopeFitResult),
SurvivalMarginalSlope(SurvivalMarginalSlopeFitResult),
LatentSurvival(LatentSurvivalTermFitResult),
LatentBinary(LatentBinaryTermFitResult),
TransformationNormal(TransformationNormalFitResult),
SplineScan(crate::solver::spline_scan::SplineScanFit),
ResidualCascade(crate::solver::residual_cascade::ResidualCascadeFit),
}
/// Result of a dispersion location-scale fit: the blockwise fit plus the
/// dispersion family kind it was fitted with.
pub struct DispersionLocationScaleFitResult {
pub fit: BlockwiseTermFitResult,
pub kind: DispersionFamilyKind,
}
/// Cross-fitted score calibration data.
///
/// Field names suggest out-of-fold ("oof") scores and their Jacobian. The
/// expected shape is presumably `z_oof: (n,)`, `jac_oof: (n, p)` — confirm
/// against the producer.
pub struct CrossFitScoreCalibration {
pub z_oof: Array1<f64>,
pub jac_oof: Array2<f64>,
}
/// Validated recipe for a Stage-1 transformation-normal fit.
///
/// Holds the response column, the right-hand side of the covariate formula,
/// the transformation config, and optional weight/offset column names.
/// Build it through [`CtnStage1Recipe::new`], which normalizes and validates
/// every textual field. The fields are public, so directly constructed values
/// are not checked.
#[derive(Clone, Debug)]
pub struct CtnStage1Recipe {
    /// Stage-1 response column name (trimmed and non-empty when built via `new`).
    pub response_column: String,
    /// Covariate formula right-hand side only, e.g. `s(pc1) + s(pc2)`.
    /// When built via `new` it is trimmed, non-empty and contains no `~`.
    pub covariate_formula_rhs: String,
    /// Transformation-normal configuration for the Stage-1 fit.
    pub config: TransformationNormalConfig,
    /// Optional weight column name (trimmed; `None` when absent or blank).
    pub weight_column: Option<String>,
    /// Optional offset column name (trimmed; `None` when absent or blank).
    pub offset_column: Option<String>,
}

impl CtnStage1Recipe {
    /// Validates and normalizes a Stage-1 recipe.
    ///
    /// Every column name and the covariate RHS is trimmed of surrounding
    /// whitespace, so the stored names match the data columns they refer to.
    /// Optional column names that are absent or blank become `None`.
    ///
    /// # Errors
    ///
    /// Returns `Err` with a descriptive message if `response` is blank, if
    /// `covariates` is blank, or if `covariates` contains `~` (a full formula
    /// was passed where only a right-hand side is expected).
    pub fn new(
        response: &str,
        covariates: &str,
        config: TransformationNormalConfig,
        weight_column: Option<&str>,
        offset_column: Option<&str>,
    ) -> Result<Self, String> {
        let response_column = response.trim().to_string();
        if response_column.is_empty() {
            return Err("CtnStage1Recipe requires a non-empty Stage-1 response column".to_string());
        }
        let covariate_formula_rhs = covariates.trim().to_string();
        if covariate_formula_rhs.is_empty() {
            return Err(
                "CtnStage1Recipe requires a non-empty Stage-1 covariate formula RHS".to_string(),
            );
        }
        if covariate_formula_rhs.contains('~') {
            return Err(
                "CtnStage1Recipe covariates is a right-hand side only; pass 's(pc1) + s(pc2)', \
                 not 'score ~ s(pc1) + s(pc2)'"
                    .to_string(),
            );
        }
        Ok(Self {
            response_column,
            covariate_formula_rhs,
            config,
            weight_column: Self::normalize_optional_column(weight_column),
            offset_column: Self::normalize_optional_column(offset_column),
        })
    }

    /// Trims an optional column name and maps absent/blank names to `None`.
    ///
    /// The trimmed value is what gets stored. Previously only the emptiness
    /// check trimmed, so `" w "` was kept verbatim and would not match column `w`.
    fn normalize_optional_column(name: Option<&str>) -> Option<String> {
        name.map(str::trim)
            .filter(|s| !s.is_empty())
            .map(str::to_string)
    }
}
/// High-level fit configuration.
///
/// Several choices are carried as strings (`baseline_target`, `time_basis`,
/// `survival_likelihood`, `survival_distribution`). Many fields are optional
/// overrides that are `None` by default. The defaults noted below are the
/// values assigned by `FitConfig::default()`.
#[derive(Clone, Debug)]
pub struct FitConfig {
// --- family / link ---
/// Family name; `None` by default.
pub family: Option<String>,
/// Negative-binomial theta; `None` by default.
pub negative_binomial_theta: Option<f64>,
/// Link name; `None` by default.
pub link: Option<String>,
/// Defaults to `false`.
pub flexible_link: bool,
// --- column names ---
/// `None` by default.
pub offset_column: Option<String>,
/// `None` by default.
pub noise_offset_column: Option<String>,
// --- survival baseline ---
/// `None` by default.
pub frailty: Option<FrailtySpec>,
/// Defaults to `"linear"`.
pub baseline_target: String,
/// Optional baseline parameters; all `None` by default.
pub baseline_scale: Option<f64>,
pub baseline_shape: Option<f64>,
pub baseline_rate: Option<f64>,
pub baseline_makeham: Option<f64>,
// --- survival time basis / likelihood ---
/// Defaults to `"ispline"`.
pub time_basis: String,
/// Defaults to 3.
pub time_degree: usize,
/// Defaults to 8.
pub time_num_internal_knots: usize,
/// Defaults to `1e-2`.
pub time_smooth_lambda: f64,
/// Defaults to `"transformation"`.
pub survival_likelihood: String,
/// Defaults to `"gaussian"`.
pub survival_distribution: String,
/// `None` by default; degree defaults to 3.
pub threshold_time_k: Option<usize>,
pub threshold_time_degree: usize,
/// `None` by default; degree defaults to 3.
pub sigma_time_k: Option<usize>,
pub sigma_time_degree: usize,
// --- auxiliary formulas / columns ---
/// `None` by default.
pub noise_formula: Option<String>,
/// `None` by default.
pub logslope_formula: Option<String>,
/// `None` by default.
pub z_column: Option<String>,
/// `None` by default.
pub weight_column: Option<String>,
/// `None` by default.
pub expectile_tau: Option<f64>,
/// Optional Stage-1 recipe; `None` by default.
pub ctn_stage1: Option<CtnStage1Recipe>,
// --- regularization / solver ---
/// Defaults to `false`.
pub scale_dimensions: bool,
/// Tri-state override, `None` by default (presumably `None` leaves the choice
/// to the fitter — confirm).
pub adaptive_regularization: Option<bool>,
/// Defaults to `1e-6`.
pub ridge_lambda: f64,
/// Defaults to `false`.
pub transformation_normal: bool,
/// Defaults to `false`.
pub firth: bool,
/// `None` by default.
pub outer_max_iter: Option<usize>,
// --- resources ---
/// Defaults to `GpuPolicy::Auto`.
pub gpu_policy: crate::gpu::GpuPolicy,
/// `None` by default.
pub resource_policy: Option<crate::resource::ResourcePolicy>,
// --- structured extensions (raw JSON where typed as `JsonValue`) ---
/// `None` by default.
pub group_metadata: Option<BTreeMap<String, JsonValue>>,
/// Empty by default.
pub coefficient_groups: Vec<CoefficientGroupSpec>,
/// Gamma priors keyed by penalty-block name; empty by default.
pub penalty_block_gamma_priors: Vec<(String, f64, f64)>,
/// `None` by default.
pub latents: Option<JsonValue>,
/// `None` by default.
pub analytic_penalties: Option<JsonValue>,
/// `None` by default.
pub topology_auto_selector: Option<crate::solver::topology_selector::TopologyAutoSelector>,
/// `None` by default.
pub smooth_overrides: Option<JsonValue>,
/// Defaults to `false`.
pub persist_warm_start_disk: bool,
}
impl Default for FitConfig {
fn default() -> Self {
Self {
family: None,
negative_binomial_theta: None,
link: None,
flexible_link: false,
offset_column: None,
noise_offset_column: None,
frailty: None,
baseline_target: "linear".into(),
baseline_scale: None,
baseline_shape: None,
baseline_rate: None,
baseline_makeham: None,
time_basis: "ispline".into(),
time_degree: 3,
time_num_internal_knots: 8,
time_smooth_lambda: 1e-2,
survival_likelihood: "transformation".into(),
survival_distribution: "gaussian".into(),
threshold_time_k: None,
threshold_time_degree: 3,
sigma_time_k: None,
sigma_time_degree: 3,
noise_formula: None,
logslope_formula: None,
z_column: None,
weight_column: None,
expectile_tau: None,
ctn_stage1: None,
scale_dimensions: false,
adaptive_regularization: None,
ridge_lambda: 1e-6,
transformation_normal: false,
firth: false,
outer_max_iter: None,
gpu_policy: crate::gpu::GpuPolicy::Auto,
resource_policy: None,
group_metadata: None,
coefficient_groups: Vec::new(),
penalty_block_gamma_priors: Vec::new(),
latents: None,
analytic_penalties: None,
topology_auto_selector: None,
smooth_overrides: None,
persist_warm_start_disk: false,
}
}
}
/// A fit request together with free-text notes collected while building it
/// (presumably notes about inferred settings, per the field name — confirm).
pub struct MaterializedModel<'a> {
pub request: FitRequest<'a>,
pub inference_notes: Vec<String>,
}
/// Plain-vector inputs for a spline scan.
///
/// Presumably `x`, `y` and `w` are equal-length per-observation vectors and
/// `order` is the spline/penalty order — confirm against the consumer.
pub struct SplineScanInputs {
pub x: Vec<f64>,
pub y: Vec<f64>,
pub w: Vec<f64>,
pub order: usize,
}
/// Plain-vector inputs for a residual cascade.
///
/// `coords` is a `Vec<Vec<f64>>`, presumably one coordinate vector per
/// observation. `y` and `w` presumably hold per-observation response and
/// weight. The meaning of `metric` and `sobolev_s` is defined by the
/// consumer — TODO confirm.
pub struct ResidualCascadeInputs {
pub coords: Vec<Vec<f64>>,
pub y: Vec<f64>,
pub w: Vec<f64>,
pub metric: Vec<f64>,
pub sobolev_s: f64,
}