use super::*;
pub fn fit_model(request: FitRequest<'_>) -> Result<FitResult, WorkflowError> {
let mut request = request;
let exact_key = request.cache_key();
let seed_key = request.cache_seed_key();
if let Some(session) = crate::solver::persistent_warm_start::open_outer_session(&exact_key) {
let exact_present = session.peek_load().is_some();
if !exact_present
&& let Some(seed) =
crate::solver::persistent_warm_start::lookup_outer_iterate_payload(&seed_key)
{
let prior_obj = seed.objective.unwrap_or(f64::NAN);
log::info!(
"[CACHE] seed key={}.. via prefix family={} prior_obj={:.6e}",
&exact_key[..8.min(exact_key.len())],
request.family_tag(),
prior_obj,
);
session.preload(seed);
}
request.attach_cache_session(session);
}
let mirror_session = crate::solver::persistent_warm_start::open_outer_session(&seed_key);
if let Some(mirror) = mirror_session.as_ref() {
request.attach_cache_mirror(Arc::clone(mirror));
}
let wrap_solver_err =
|reason: String| -> WorkflowError { WorkflowError::IntegrationFailed { reason } };
match request {
FitRequest::Standard(request) => fit_standard_model(request)
.map(FitResult::Standard)
.map_err(wrap_solver_err),
FitRequest::GaussianLocationScale(request) => fit_gaussian_location_scale_model(request)
.map(FitResult::GaussianLocationScale)
.map_err(wrap_solver_err),
FitRequest::BinomialLocationScale(request) => fit_binomial_location_scale_model(request)
.map(FitResult::BinomialLocationScale)
.map_err(wrap_solver_err),
FitRequest::DispersionLocationScale(request) => {
fit_dispersion_location_scale_model(request)
.map(FitResult::DispersionLocationScale)
.map_err(wrap_solver_err)
}
FitRequest::SurvivalLocationScale(request) => {
match fit_survival_location_scale_model(request).map(FitResult::SurvivalLocationScale) {
Ok(fit) => Ok(fit),
Err(e)
if e.contains("expects 3 blocks, got 0")
|| e.contains("expects 4 blocks, got 0")
|| (e.contains("block_states") && e.contains("got 0"))
|| e.contains("blockwise fit requires at least one block state") =>
{
Err(WorkflowError::IntegrationFailed {
reason: format!(
"survival location-scale fit failed: the smoothing-parameter optimizer \
landed at a degenerate iterate where the inner solver's block state \
was empty. This is the symptom of an under-identified smooth driven \
to a numerically pathological λ (e.g. exp(20+)) on a small-data \
subsample. Try: (1) reducing covariate count, (2) increasing n_train, \
(3) `baseline_target=\"linear\"` to drop the parametric baseline, or \
(4) `noise_formula=\"1\"` to drop the noise GAM. Underlying error: {e}"
),
})
}
Err(reason) => Err(wrap_solver_err(reason)),
}
}
FitRequest::SurvivalTransformation(request) => fit_survival_transformation_model(request)
.map(FitResult::SurvivalTransformation)
.map_err(wrap_solver_err),
FitRequest::BernoulliMarginalSlope(request) => fit_bernoulli_marginal_slope_model(request)
.map(FitResult::BernoulliMarginalSlope)
.map_err(wrap_solver_err),
FitRequest::SurvivalMarginalSlope(request) => fit_survival_marginal_slope_model(request)
.map(FitResult::SurvivalMarginalSlope)
.map_err(wrap_solver_err),
FitRequest::LatentSurvival(request) => fit_latent_survival_model(request)
.map(FitResult::LatentSurvival)
.map_err(wrap_solver_err),
FitRequest::LatentBinary(request) => fit_latent_binary_model(request)
.map(FitResult::LatentBinary)
.map_err(wrap_solver_err),
FitRequest::TransformationNormal(request) => fit_transformation_normal_model(request)
.map(FitResult::TransformationNormal)
.map_err(wrap_solver_err),
}
}
/// Returns the resource policy for a fit.
///
/// An explicit `config.resource_policy` takes precedence. Otherwise a policy is
/// derived from the dataset's row count and the supplied problem `hints`.
pub(crate) fn resolved_resource_policy(
    config: &FitConfig,
    data: &Dataset,
    hints: crate::resource::ProblemHints,
) -> crate::resource::ResourcePolicy {
    config.resource_policy.clone().unwrap_or_else(|| {
        let n_rows = data.values.nrows();
        crate::resource::ResourcePolicy::for_problem(n_rows, 0, hints)
    })
}
/// Builds resource-planning hints for marginal-slope models.
///
/// The large-scale marginal-slope hint is active when the config supplies
/// either a `logslope_formula` or a `z_column`.
pub(crate) fn marginal_slope_hints(config: &FitConfig) -> crate::resource::ProblemHints {
    let has_logslope = config.logslope_formula.is_some();
    let has_z_column = config.z_column.is_some();
    crate::resource::ProblemHints {
        marginal_slope_large_scale_active: has_logslope || has_z_column,
    }
}
/// Resolves the expectile asymmetry τ requested by `config`, if any.
///
/// Returns `Ok(None)` when `config.family` is absent, or when it is neither
/// `expectile` nor `expectile(<τ>)`. The comparison is case-insensitive and
/// ignores surrounding whitespace. Otherwise τ is taken from the inline form or
/// from `config.expectile_tau`, and defaults to 0.5 when neither supplies it.
///
/// # Errors
/// Returns [`WorkflowError::InvalidConfig`] when:
/// - the inline form lacks its closing parenthesis;
/// - the inline argument does not parse as `f64`;
/// - the inline value and `expectile_tau` are both given and differ;
/// - the resolved τ is not finite and strictly inside `(0, 1)`.
fn expectile_tau_for_config(config: &FitConfig) -> Result<Option<f64>, WorkflowError> {
    let invalid = |reason: String| WorkflowError::InvalidConfig { reason };
    let trimmed = match config.family.as_deref() {
        Some(raw) => raw.trim(),
        None => return Ok(None),
    };
    let lower = trimmed.to_ascii_lowercase();
    let inline_tau = match lower.strip_prefix("expectile(") {
        Some(rest) => {
            let Some(inner) = rest.strip_suffix(')') else {
                return Err(invalid(format!(
                    "expectile family asymmetry must be written as `expectile(τ)`; got `{trimmed}`"
                )));
            };
            let inner = inner.trim();
            match inner.parse::<f64>() {
                Ok(value) => Some(value),
                Err(_) => {
                    return Err(invalid(format!(
                        "expectile asymmetry `{inner}` is not a finite number"
                    )));
                }
            }
        }
        // Bare `expectile`: τ must come from `expectile_tau` or the default.
        None if lower == "expectile" => None,
        // Some other family: expectile handling does not apply.
        None => return Ok(None),
    };
    let tau = match (inline_tau, config.expectile_tau) {
        // A NaN difference does not trip this guard. A NaN τ is rejected by the
        // range check below.
        (Some(inline), Some(explicit)) if (inline - explicit).abs() > 0.0 => {
            return Err(invalid(format!(
                "expectile asymmetry given both inline (`expectile({inline})`) and via expectile_tau \
                 ({explicit}); supply exactly one"
            )));
        }
        (Some(inline), _) => inline,
        (None, explicit) => explicit.unwrap_or(0.5),
    };
    if tau.is_finite() && tau > 0.0 && tau < 1.0 {
        Ok(Some(tau))
    } else {
        Err(invalid(format!(
            "expectile asymmetry τ must be finite and strictly in (0, 1); got {tau}"
        )))
    }
}
fn expectile_row_weights(
y: ArrayView1<f64>,
mu: ArrayView1<f64>,
base: ArrayView1<f64>,
tau: f64,
) -> Array1<f64> {
Array1::from_shape_fn(y.len(), |i| {
let asym = if y[i] > mu[i] { tau } else { 1.0 - tau };
base[i] * asym
})
}
/// Parses `formula`, materializes it against `data`, and fits the model.
///
/// Dispatch order:
/// 1. An expectile family is fitted by LAWS and skips the generic path.
/// 2. Standard requests eligible for the spline-scan solver use it. Its
///    failures are returned as errors.
/// 3. Standard requests eligible for the residual cascade try it. A cascade
///    failure silently falls through to the general solver.
/// 4. Everything else goes through [`fit_model`].
pub fn fit_from_formula(
    formula: &str,
    data: &Dataset,
    config: &FitConfig,
) -> Result<FitResult, WorkflowError> {
    if let Some(tau) = expectile_tau_for_config(config)? {
        return fit_expectile_laws(formula, data, config, tau);
    }
    let mat = materialize(formula, data, config)?;
    if let FitRequest::Standard(standard) = &mat.request {
        if let Some(scan_inputs) = spline_scan_fast_path(standard) {
            return crate::solver::spline_scan::fit_spline_scan(
                &scan_inputs.x,
                &scan_inputs.y,
                &scan_inputs.w,
                scan_inputs.order,
            )
            .map(FitResult::SplineScan)
            .map_err(|reason| WorkflowError::IntegrationFailed { reason });
        }
        if let Some(cascade_inputs) = residual_cascade_fast_path(standard) {
            let axes: Vec<&[f64]> = cascade_inputs
                .coords
                .iter()
                .map(|axis| axis.as_slice())
                .collect();
            let attempt = crate::solver::residual_cascade::fit_residual_cascade(
                &axes,
                &cascade_inputs.y,
                &cascade_inputs.w,
                &cascade_inputs.metric,
                cascade_inputs.sobolev_s,
            );
            // Best-effort: on failure, fall through to the general solver below.
            if let Ok(fit) = attempt {
                return Ok(FitResult::ResidualCascade(fit));
            }
        }
    }
    fit_model(mat.request)
}
fn fit_expectile_laws(
formula: &str,
data: &Dataset,
config: &FitConfig,
tau: f64,
) -> Result<FitResult, WorkflowError> {
use crate::linalg::matrix::LinearOperator;
let gaussian_config = FitConfig {
family: Some("gaussian".to_string()),
link: Some("identity".to_string()),
expectile_tau: None,
..config.clone()
};
let base_mat = materialize(formula, data, &gaussian_config)?;
let FitRequest::Standard(base_request) = base_mat.request else {
return Err(WorkflowError::InvalidConfig {
reason: "expectile regression is only defined for standard (non-survival, \
non-location-scale) responses"
.to_string(),
});
};
let StandardFitRequest {
data: design_data,
y,
weights: base_weights,
offset,
spec,
family: materialized_family,
options,
kappa_options,
wiggle,
coefficient_groups,
penalty_block_gamma_priors,
latent_coord,
_marker,
} = base_request;
if !materialized_family.is_gaussian_identity() {
return Err(WorkflowError::InvalidConfig {
reason: format!(
"expectile LAWS requires a Gaussian-identity inner family; materializer produced {}",
materialized_family.name()
),
});
}
if wiggle.is_some() || latent_coord.is_some() {
return Err(WorkflowError::InvalidConfig {
reason: "expectile regression does not support flexible-link wiggle or latent \
coordinates"
.to_string(),
});
}
let n = y.len();
let gaussian_family = LikelihoodSpec::gaussian_identity();
let mut weights = base_weights.clone();
let mut last_sign: Option<Vec<bool>> = None;
let mut last_result: Option<StandardFitResult> = None;
const MAX_LAWS_ITERS: usize = 50;
for _iter in 0..MAX_LAWS_ITERS {
let request = StandardFitRequest {
data: design_data.clone(),
y: y.clone(),
weights: weights.clone(),
offset: offset.clone(),
spec: spec.clone(),
family: gaussian_family.clone(),
options: options.clone(),
kappa_options: kappa_options.clone(),
wiggle: None,
coefficient_groups: coefficient_groups.clone(),
penalty_block_gamma_priors: penalty_block_gamma_priors.clone(),
latent_coord: None,
_marker,
};
let result = fit_standard_model(request)
.map_err(|reason| WorkflowError::IntegrationFailed { reason })?;
let mu = result.design.design.apply(&result.fit.beta);
if mu.len() != n {
return Err(WorkflowError::IntegrationFailed {
reason: format!(
"expectile LAWS: fitted mean length {} disagrees with response length {n}",
mu.len()
),
});
}
let mut mu_off = mu;
mu_off += &offset;
let sign: Vec<bool> = (0..n).map(|i| y[i] > mu_off[i]).collect();
let converged = last_sign.as_ref().is_some_and(|prev| prev == &sign);
weights = expectile_row_weights(y.view(), mu_off.view(), base_weights.view(), tau);
last_sign = Some(sign);
last_result = Some(result);
if converged {
break;
}
}
let result = last_result.ok_or_else(|| WorkflowError::IntegrationFailed {
reason: "expectile LAWS produced no fit".to_string(),
})?;
Ok(FitResult::Standard(result))
}
/// Returns inputs for the dedicated spline-scan solver when `request` is a
/// model that solver can fit, or `None` to fall back to the general path.
///
/// All of the following are required:
/// - Gaussian-identity family.
/// - No wiggle, latent coordinate, coefficient groups, or penalty-block gamma
///   priors.
/// - No latent-cloglog, mixture, or SAS link; no linear constraints, adaptive
///   regularization, Kronecker penalties, Firth correction, or nullspace-dim
///   overrides.
/// - Exactly one smooth term and no linear or random-effect terms. That smooth
///   is an unconstrained, unrotated 1-D B-spline with penalty order `m` in
///   `1..=3`, degree `2m - 1`, no double penalty, free boundary conditions, an
///   open boundary, and non-periodic knots.
/// - An all-zero offset.
/// - Finite, strictly positive weights, one per row.
/// - Finite `x` and `y`.
/// - At least `m + 1` distinct `x` values.
pub fn spline_scan_fast_path(request: &StandardFitRequest<'_>) -> Option<SplineScanInputs> {
    if !request.family.is_gaussian_identity() {
        return None;
    }
    if request.wiggle.is_some()
        || request.latent_coord.is_some()
        || !request.coefficient_groups.is_empty()
        || !request.penalty_block_gamma_priors.is_empty()
    {
        return None;
    }
    let options = &request.options;
    if options.latent_cloglog.is_some()
        || options.mixture_link.is_some()
        || options.sas_link.is_some()
        || options.linear_constraints.is_some()
        || options.adaptive_regularization.is_some()
        || options.kronecker_penalty_system.is_some()
        || options.kronecker_factored.is_some()
        || options.firth_bias_reduction
        || !options.nullspace_dims.is_empty()
    {
        return None;
    }
    let spec = &request.spec;
    if !spec.linear_terms.is_empty()
        || !spec.random_effect_terms.is_empty()
        || spec.smooth_terms.len() != 1
    {
        return None;
    }
    let term = &spec.smooth_terms[0];
    if !matches!(term.shape, crate::smooth::ShapeConstraint::None)
        || term.joint_null_rotation.is_some()
    {
        return None;
    }
    let crate::smooth::SmoothBasisSpec::BSpline1D {
        feature_col,
        spec: bspec,
    } = &term.basis
    else {
        return None;
    };
    let order = bspec.penalty_order;
    // The range test runs first, so `2 * order - 1` cannot underflow.
    if !(1..=3).contains(&order)
        || bspec.degree != 2 * order - 1
        || bspec.double_penalty
        || !bspec.boundary_conditions.is_free()
        || !matches!(bspec.boundary, crate::basis::OneDimensionalBoundary::Open)
        || matches!(
            bspec.knotspec,
            crate::basis::BSplineKnotSpec::PeriodicUniform { .. }
        )
    {
        return None;
    }
    if request.offset.iter().any(|&v| v != 0.0) {
        return None;
    }
    if request.weights.iter().any(|&v| !(v.is_finite() && v > 0.0)) {
        return None;
    }
    // `w` is handed to the scan kernel alongside `x`/`y`, so its length must
    // match theirs too. A mismatch falls back to the general path.
    let n = request.y.len();
    if *feature_col >= request.data.ncols()
        || n != request.data.nrows()
        || request.weights.len() != n
    {
        return None;
    }
    let x: Vec<f64> = request.data.column(*feature_col).iter().copied().collect();
    let y: Vec<f64> = request.y.iter().copied().collect();
    let w: Vec<f64> = request.weights.iter().copied().collect();
    if x.iter().any(|v| !v.is_finite()) || y.iter().any(|v| !v.is_finite()) {
        return None;
    }
    let mut sorted = x.clone();
    sorted.sort_by(f64::total_cmp);
    sorted.dedup();
    if sorted.len() < order + 1 {
        return None;
    }
    Some(SplineScanInputs { x, y, w, order })
}
/// Materializes `formula` and fits it with the spline-scan solver when eligible.
///
/// Returns `Ok(None)` when the request is not a standard request, or when
/// [`spline_scan_fast_path`] declines it.
///
/// # Errors
/// Materialization errors propagate. A spline-scan solver failure is returned
/// as [`WorkflowError::IntegrationFailed`].
pub fn fit_spline_scan_from_formula(
    formula: &str,
    data: &Dataset,
    config: &FitConfig,
) -> Result<Option<crate::solver::spline_scan::SplineScanFit>, WorkflowError> {
    let mat = materialize(formula, data, config)?;
    let FitRequest::Standard(request) = mat.request else {
        return Ok(None);
    };
    spline_scan_fast_path(&request)
        .map(|inputs| {
            crate::solver::spline_scan::fit_spline_scan(
                &inputs.x,
                &inputs.y,
                &inputs.w,
                inputs.order,
            )
            .map_err(|reason| WorkflowError::IntegrationFailed { reason })
        })
        .transpose()
}
/// Reports whether the default kernel center count for `n` rows in `d`
/// dimensions reaches the dense-kernel cap of 2000 centers.
fn past_dense_kernel_cliff(n: usize, d: usize) -> bool {
    // Center-count threshold at or above which the residual cascade is eligible.
    const DENSE_CENTER_CAP: usize = 2000;
    let centers = crate::terms::basis::default_num_centers(n, d);
    centers >= DENSE_CENTER_CAP
}
/// Clamps a requested Sobolev order into `(d/2, (d+3)/2]` for the residual cascade.
///
/// The lower end is nudged up by `1e-6` of the band width, so the result is
/// strictly above `d/2`. A NaN request is returned unchanged, following
/// [`f64::clamp`] semantics.
fn cascade_sobolev_order(requested: f64, d: usize) -> f64 {
    let dim = d as f64;
    let floor = dim / 2.0;
    let ceiling = (dim + 3.0) / 2.0;
    let margin = 1e-6 * (ceiling - floor);
    let strict_floor = floor + margin;
    requested.clamp(strict_floor, ceiling)
}
/// Returns inputs for the residual-cascade solver when `request` is a model it
/// can fit, or `None` to fall back to the general path.
///
/// Requirements:
/// - The same family, option, and term restrictions as
///   [`spline_scan_fast_path`].
/// - The single smooth is an unconstrained, unrotated Duchon or Matérn basis
///   over 2 or 3 feature columns.
/// - An all-zero offset, and finite, strictly positive weights, one per row.
/// - Finite coordinates and response.
/// - A row count large enough that [`past_dense_kernel_cliff`] reports true.
///
/// The Sobolev order is `power + nullspace degree` for Duchon, or
/// `nu + d/2` for Matérn, clamped by [`cascade_sobolev_order`]. The metric is 1
/// on every axis.
pub fn residual_cascade_fast_path(
    request: &StandardFitRequest<'_>,
) -> Option<ResidualCascadeInputs> {
    if !request.family.is_gaussian_identity() {
        return None;
    }
    if request.wiggle.is_some()
        || request.latent_coord.is_some()
        || !request.coefficient_groups.is_empty()
        || !request.penalty_block_gamma_priors.is_empty()
    {
        return None;
    }
    let options = &request.options;
    if options.latent_cloglog.is_some()
        || options.mixture_link.is_some()
        || options.sas_link.is_some()
        || options.linear_constraints.is_some()
        || options.adaptive_regularization.is_some()
        || options.kronecker_penalty_system.is_some()
        || options.kronecker_factored.is_some()
        || options.firth_bias_reduction
        || !options.nullspace_dims.is_empty()
    {
        return None;
    }
    let spec = &request.spec;
    if !spec.linear_terms.is_empty()
        || !spec.random_effect_terms.is_empty()
        || spec.smooth_terms.len() != 1
    {
        return None;
    }
    let term = &spec.smooth_terms[0];
    if !matches!(term.shape, crate::smooth::ShapeConstraint::None)
        || term.joint_null_rotation.is_some()
    {
        return None;
    }
    let (feature_cols, requested_s) = match &term.basis {
        crate::smooth::SmoothBasisSpec::Duchon {
            feature_cols, spec, ..
        } => {
            let p = match spec.nullspace_order {
                crate::basis::DuchonNullspaceOrder::Zero => 0.0,
                crate::basis::DuchonNullspaceOrder::Linear => 1.0,
                crate::basis::DuchonNullspaceOrder::Degree(k) => k as f64,
            };
            (feature_cols, spec.power + p)
        }
        crate::smooth::SmoothBasisSpec::Matern {
            feature_cols, spec, ..
        } => {
            let nu = spec.nu.half_integer_value();
            (feature_cols, nu + feature_cols.len() as f64 / 2.0)
        }
        _ => return None,
    };
    let d = feature_cols.len();
    if !(2..=3).contains(&d) {
        return None;
    }
    if request.offset.iter().any(|&v| v != 0.0) {
        return None;
    }
    if request.weights.iter().any(|&v| !(v.is_finite() && v > 0.0)) {
        return None;
    }
    // `w` is passed to the cascade alongside `y`, so it needs one entry per row.
    let n = request.y.len();
    if n != request.data.nrows()
        || request.weights.len() != n
        || feature_cols.iter().any(|&c| c >= request.data.ncols())
    {
        return None;
    }
    if !past_dense_kernel_cliff(n, d) {
        return None;
    }
    let coords: Vec<Vec<f64>> = feature_cols
        .iter()
        .map(|&c| request.data.column(c).iter().copied().collect())
        .collect();
    let y: Vec<f64> = request.y.iter().copied().collect();
    let w: Vec<f64> = request.weights.iter().copied().collect();
    if coords
        .iter()
        .any(|axis| axis.iter().any(|v| !v.is_finite()))
        || y.iter().any(|v| !v.is_finite())
    {
        return None;
    }
    let metric = vec![1.0_f64; d];
    let sobolev_s = cascade_sobolev_order(requested_s, d);
    Some(ResidualCascadeInputs {
        coords,
        y,
        w,
        metric,
        sobolev_s,
    })
}
/// Materializes `formula` and fits it with the residual-cascade solver when eligible.
///
/// Returns `Ok(None)` in three cases: the request is not a standard request,
/// [`residual_cascade_fast_path`] declines it, or the cascade solver itself
/// fails. Solver failure is treated as "fast path unavailable", not as an error.
///
/// # Errors
/// Only materialization errors propagate.
pub fn fit_residual_cascade_from_formula(
    formula: &str,
    data: &Dataset,
    config: &FitConfig,
) -> Result<Option<crate::solver::residual_cascade::ResidualCascadeFit>, WorkflowError> {
    let mat = materialize(formula, data, config)?;
    let FitRequest::Standard(request) = mat.request else {
        return Ok(None);
    };
    let Some(inputs) = residual_cascade_fast_path(&request) else {
        return Ok(None);
    };
    let axes: Vec<&[f64]> = inputs.coords.iter().map(|axis| axis.as_slice()).collect();
    let outcome = crate::solver::residual_cascade::fit_residual_cascade(
        &axes,
        &inputs.y,
        &inputs.w,
        &inputs.metric,
        inputs.sobolev_s,
    );
    Ok(outcome.ok())
}
/// Parses `formula` and materializes it into a fit request over `data`.
///
/// First applies `config.gpu_policy` globally. Then it selects a materializer
/// in this order:
/// 1. A `SurvInterval(...)` response, or
/// 2. a `Surv(...)` response, goes to `materialize_survival`. Neither can be
///    combined with `transformation_normal`.
/// 3. Otherwise survival-only terms are rejected.
/// 4. `transformation_normal` (incompatible with marginal-slope controls and
///    `noise_formula`) is checked next.
/// 5. Then marginal slope (`logslope_formula` or `z_column`).
/// 6. Then location-scale (`noise_formula`).
/// 7. Otherwise the standard materializer is used.
///
/// # Errors
/// Propagates parse and materializer errors. Returns
/// [`WorkflowError::InvalidConfig`] for the incompatible combinations above.
pub fn materialize<'a>(
    formula: &str,
    data: &'a Dataset,
    config: &FitConfig,
) -> Result<MaterializedModel<'a>, WorkflowError> {
    crate::gpu::configure_global_policy(config.gpu_policy);
    let parsed = parse_formula(formula)?;
    let col_map = data.column_map();

    if let Some((left_col, right_col, event_col)) = parse_surv_interval_response(&parsed.response)?
    {
        if config.transformation_normal {
            return Err(WorkflowError::InvalidConfig {
                reason: "transformation_normal cannot be combined with a SurvInterval(...) response"
                    .to_string(),
            });
        }
        return materialize_survival(
            &parsed,
            data,
            &col_map,
            config,
            None,
            &left_col,
            &event_col,
            Some(&right_col),
        );
    }

    if let Some((entry_col, exit_col, event_col)) = parse_surv_response(&parsed.response)? {
        if config.transformation_normal {
            return Err(WorkflowError::InvalidConfig {
                reason: "transformation_normal cannot be combined with a Surv(...) response"
                    .to_string(),
            });
        }
        return materialize_survival(
            &parsed,
            data,
            &col_map,
            config,
            entry_col.as_deref(),
            &exit_col,
            &event_col,
            None,
        );
    }

    // Non-survival response from here on.
    reject_survival_only_terms_for_nonsurvival(&parsed)?;
    if config.transformation_normal {
        reject_marginal_slope_controls_for_transformation_normal(config)?;
        if config.noise_formula.is_some() {
            return Err(WorkflowError::InvalidConfig {
                reason: "transformation_normal cannot be combined with noise_formula".to_string(),
            });
        }
        return materialize_transformation_normal(&parsed, data, &col_map, config);
    }
    let wants_marginal_slope = config.logslope_formula.is_some() || config.z_column.is_some();
    if wants_marginal_slope {
        materialize_bernoulli_marginal_slope(&parsed, data, &col_map, config)
    } else if config.noise_formula.is_some() {
        materialize_location_scale(&parsed, data, &col_map, config)
    } else {
        materialize_standard(&parsed, data, &col_map, config)
    }
}