use super::*;
use crate::basis::{DuchonNullspaceOrder, minimum_duchon_power_for_operator_penalties};
use crate::inference::data::load_dataset_projected;
use crate::inference::formula_dsl::{default_linkwiggle_formulaspec, parse_linkwiggle_formulaspec};
use crate::inference::model::{ColumnKindTag, DataSchema, SchemaColumn};
use crate::smooth::SmoothBasisSpec;
use crate::solver::rho_optimizer::{HessianSource, OuterPlan, OuterResult, Solver};
use ndarray::Array2;
use std::fs;
use tempfile::tempdir;
/// Writes a two-row survival CSV with columns `entry,exit,event,x,z` into a
/// fresh temporary directory and loads it via `load_dataset_projected`,
/// requesting all five columns by name.
///
/// Panics if the temporary directory cannot be created, the file cannot be
/// written, or the loader returns an error.
fn load_survival_dataset() -> crate::inference::data::EncodedDataset {
    const SURVIVAL_CSV: &str = "entry,exit,event,x,z\n0.0,1.0,1,0.2,-0.4\n0.3,1.6,0,-0.1,0.6\n";

    let scratch_dir = tempdir().expect("tempdir");
    let csv_path = scratch_dir.path().join("survival.csv");
    fs::write(&csv_path, SURVIVAL_CSV).expect("write survival csv");

    let requested_columns = [
        String::from("entry"),
        String::from("exit"),
        String::from("event"),
        String::from("x"),
        String::from("z"),
    ];
    load_dataset_projected(&csv_path, &requested_columns).expect("load survival dataset")
}
/// For one, two and three causes, checks that
/// `replicate_pooled_baseline_seed_per_cause` returns `p * cause_count` entries
/// and that every consecutive length-`p` block equals the pooled seed.
#[test]
fn competing_risks_baseline_seed_replicates_to_match_cause_specific_beta_length() {
    let pooled = Array1::from_vec(vec![-1.5_f64, 0.8, 0.0]);
    let p = pooled.len();
    for cause_count in 1usize..=3 {
        let flat = replicate_pooled_baseline_seed_per_cause(pooled.view(), cause_count);
        let expected_len = p * cause_count;
        assert_eq!(
            flat.len(),
            expected_len,
            "replicated seed must satisfy the `p * cause_count` length contract"
        );
        for cause in 0..cause_count {
            // Block `cause` occupies the half-open range [cause * p, cause * p + p).
            let start = cause * p;
            let block = flat.slice(s![start..start + p]).to_owned();
            assert_eq!(
                block, pooled,
                "cause {cause} block must be seeded from the pooled baseline"
            );
        }
    }
}
/// With `z_column = "z"` and a constant logslope formula, a main formula that
/// also uses `z` must fail to materialize, and the error must mention both the
/// reserved z column and the main formula.
#[test]
fn survival_marginal_slope_materialize_rejects_z_column_in_main_formula() {
    let data = load_survival_dataset();
    let config = FitConfig {
        survival_likelihood: "marginal-slope".to_string(),
        logslope_formula: Some("1".to_string()),
        z_column: Some("z".to_string()),
        ..FitConfig::default()
    };
    let Err(err) = materialize("Surv(entry, exit, event) ~ x + z", &data, &config) else {
        panic!("main formula should reject z-column reuse");
    };
    let rendered = err.to_string();
    assert!(rendered.contains("survival marginal-slope reserves z column 'z'"));
    assert!(rendered.contains("main formula"));
}
/// With `z_column = "z"`, a logslope formula that uses `z` must fail to
/// materialize, and the error must mention both the reserved z column and
/// `logslope_formula`.
#[test]
fn survival_marginal_slope_materialize_rejects_z_column_in_logslope_formula() {
    let data = load_survival_dataset();
    let config = FitConfig {
        survival_likelihood: "marginal-slope".to_string(),
        logslope_formula: Some("1 + z".to_string()),
        z_column: Some("z".to_string()),
        ..FitConfig::default()
    };
    let Err(err) = materialize("Surv(entry, exit, event) ~ x", &data, &config) else {
        panic!("logslope formula should reject z-column reuse");
    };
    let rendered = err.to_string();
    assert!(rendered.contains("survival marginal-slope reserves z column 'z'"));
    assert!(rendered.contains("logslope_formula"));
}
/// Same rejection as the explicit-logslope cases, but with no
/// `logslope_formula` configured: a main formula using the reserved `z`
/// column must still be refused and the error must point at the main formula.
#[test]
fn survival_marginal_slope_materialize_rejects_z_column_when_logslope_defaults_to_main_spec() {
    let data = load_survival_dataset();
    let config = FitConfig {
        survival_likelihood: "marginal-slope".to_string(),
        z_column: Some("z".to_string()),
        ..FitConfig::default()
    };
    let Err(err) = materialize("Surv(entry, exit, event) ~ x + z", &data, &config) else {
        panic!("defaulted logslope spec should still reject z-column reuse");
    };
    let rendered = err.to_string();
    assert!(rendered.contains("survival marginal-slope reserves z column 'z'"));
    assert!(rendered.contains("main formula"));
}
/// Materializes a survival marginal-slope model whose main and logslope
/// formulas both contain a three-covariate Matern term, once with and once
/// without an additional `sex` term in the main formula.
///
/// For the raw joint designs and for the designs rebuilt from the frozen
/// specs, each block must have more than two columns and every penalty matrix
/// must be square at that block's design width.
#[test]
fn survival_marginal_slope_matern_logslope_penalties_keep_surface_width() {
    // (column name, is-binary) in dataset column order.
    const COLUMNS: [(&str, bool); 8] = [
        ("t0", false),
        ("t1", false),
        ("event", true),
        ("z", false),
        ("PC1", false),
        ("PC2", false),
        ("PC3", false),
        ("sex", true),
    ];
    let kind_of = |binary: bool| {
        if binary {
            ColumnKindTag::Binary
        } else {
            ColumnKindTag::Continuous
        }
    };

    let n = 24usize;
    let mut values = Array2::<f64>::zeros((n, 8));
    for i in 0..n {
        let u = i as f64 / (n - 1) as f64;
        let row = [
            0.0,
            0.25 + 8.0 * u,
            if i % 3 == 0 { 1.0 } else { 0.0 },
            ((i * 17 % 23) as f64 - 11.0) / 7.0,
            (2.0 * std::f64::consts::PI * u).sin(),
            (2.0 * std::f64::consts::PI * u).cos(),
            2.0 * u - 1.0,
            if i % 2 == 0 { 0.0 } else { 1.0 },
        ];
        for (j, cell) in row.iter().enumerate() {
            values[[i, j]] = *cell;
        }
    }

    let data = Dataset {
        headers: COLUMNS.iter().map(|(name, _)| name.to_string()).collect(),
        values,
        schema: DataSchema {
            columns: COLUMNS
                .iter()
                .map(|&(name, binary)| SchemaColumn {
                    name: name.to_string(),
                    kind: kind_of(binary),
                    levels: vec![],
                })
                .collect(),
        },
        column_kinds: COLUMNS.iter().map(|&(_, binary)| kind_of(binary)).collect(),
    };

    let scenarios = [
        (
            "with parametric sex term",
            "Surv(t0, t1, event) ~ matern(PC1, PC2, PC3, centers=6) + sex",
        ),
        (
            "without parametric sex term",
            "Surv(t0, t1, event) ~ matern(PC1, PC2, PC3, centers=6)",
        ),
    ];
    for (case, formula) in scenarios {
        let config = FitConfig {
            survival_likelihood: "marginal-slope".to_string(),
            logslope_formula: Some("matern(PC1, PC2, PC3, centers=6)".to_string()),
            z_column: Some("z".to_string()),
            ..FitConfig::default()
        };
        let materialized = match materialize(formula, &data, &config) {
            Ok(model) => model,
            Err(err) => panic!(
                "survival marginal-slope materialization should keep block-local penalties \
                 {case}: {err}"
            ),
        };
        let request = match materialized.request {
            FitRequest::SurvivalMarginalSlope(request) => request,
            _ => panic!("expected survival marginal-slope request for {case}"),
        };

        // First pass: build from the materialized specs and capture the frozen specs.
        let specs = vec![
            request.spec.marginalspec.clone(),
            request.spec.logslopespec.clone(),
        ];
        let (designs, frozen_specs) =
            match crate::smooth::build_term_collection_designs_and_freeze_joint(
                data.values.view(),
                &specs,
            ) {
                Ok(built) => built,
                Err(err) => {
                    panic!("joint freeze should preserve per-block penalty geometry {case}: {err}")
                }
            };
        // Second pass: rebuild from the frozen specs.
        let (rebuilt, _) = match crate::smooth::build_term_collection_designs_and_freeze_joint(
            data.values.view(),
            &frozen_specs,
        ) {
            Ok(built) => built,
            Err(err) => {
                panic!("frozen rebuild should preserve per-block penalty geometry {case}: {err}")
            }
        };

        let labelled = [
            ("raw marginal", &designs[0]),
            ("raw logslope", &designs[1]),
            ("frozen marginal", &rebuilt[0]),
            ("frozen logslope", &rebuilt[1]),
        ];
        for (label, design) in labelled {
            let width = design.design.ncols();
            assert!(
                width > 2,
                "{case}: {label} design should be surface-width, not sex/intercept-width; \
                 width={width}"
            );
            for (idx, penalty) in design.penalties_as_penalty_matrix().iter().enumerate() {
                assert_eq!(
                    penalty.shape(),
                    (width, width),
                    "{case}: {label} penalty {idx} must be block-local at the surface width"
                );
            }
        }
    }
}
/// Builds a four-row in-memory `Dataset` with columns `age_entry`, `age_exit`,
/// `event` (binary), `bmi` and `z`.
///
/// Panics if the 4x5 value matrix cannot be constructed.
fn workflow_test_dataset() -> Dataset {
    // (column name, is-binary) in dataset column order.
    const COLUMNS: [(&str, bool); 5] = [
        ("age_entry", false),
        ("age_exit", false),
        ("event", true),
        ("bmi", false),
        ("z", false),
    ];
    // One inner array per observation; flattened in row order below.
    const ROWS: [[f64; 5]; 4] = [
        [40.0, 43.0, 1.0, 22.0, -1.0],
        [41.0, 46.0, 0.0, 24.0, -0.2],
        [42.0, 47.0, 1.0, 27.0, 0.3],
        [44.0, 49.0, 0.0, 29.0, 1.2],
    ];
    let kind_of = |binary: bool| {
        if binary {
            ColumnKindTag::Binary
        } else {
            ColumnKindTag::Continuous
        }
    };

    Dataset {
        headers: COLUMNS.iter().map(|(name, _)| name.to_string()).collect(),
        values: Array2::from_shape_vec((4, 5), ROWS.concat()).expect("workflow test data shape"),
        schema: DataSchema {
            columns: COLUMNS
                .iter()
                .map(|&(name, binary)| SchemaColumn {
                    name: name.to_string(),
                    kind: kind_of(binary),
                    levels: vec![],
                })
                .collect(),
        },
        column_kinds: COLUMNS.iter().map(|&(_, binary)| kind_of(binary)).collect(),
    }
}
/// Setting `transformation_normal` together with the Bernoulli marginal-slope
/// family, a logslope formula and a z column must make `materialize` fail with
/// a message stating that the two cannot be combined.
#[test]
fn issue_789_transformation_normal_rejects_marginal_slope_controls_before_dispatch() {
    let data = workflow_test_dataset();
    let config = FitConfig {
        transformation_normal: true,
        family: Some("bernoulli-marginal-slope".to_string()),
        logslope_formula: Some("1".to_string()),
        z_column: Some("z".to_string()),
        ..FitConfig::default()
    };
    let Err(err) = materialize("event ~ bmi", &data, &config) else {
        panic!("transformation_normal must not steal marginal-slope fits");
    };
    let rendered = err.to_string();
    assert!(rendered.contains("transformation_normal cannot be combined with marginal-slope"));
}
/// Zeroes the `event` column of the workflow dataset and checks that
/// survival marginal-slope materialization fails with an error mentioning
/// "at least one target event".
#[test]
fn survival_marginal_slope_rejects_zero_event_data_before_fit() {
    // `event` is the third column of `workflow_test_dataset`.
    const EVENT_COLUMN: usize = 2;

    let mut data = workflow_test_dataset();
    data.values.column_mut(EVENT_COLUMN).fill(0.0);

    let config = FitConfig {
        survival_likelihood: "marginal-slope".to_string(),
        logslope_formula: Some("1".to_string()),
        z_column: Some("z".to_string()),
        ..FitConfig::default()
    };
    let Err(err) = materialize("Surv(age_entry, age_exit, event) ~ bmi", &data, &config) else {
        panic!("zero-event survival marginal-slope data must fail before optimization");
    };
    assert!(err.to_string().contains("at least one target event"));
}
/// Builds an `OuterResult` fixture for the given convergence flag and `rho`,
/// using a BFGS plan with the BFGS-approximate Hessian source and a final
/// gradient norm of `0.5`.
///
/// The literals `1.25` and `7` are passed positionally to `OuterResult::new`;
/// presumably the final objective and an iteration count — confirm against the
/// constructor.
fn workflow_test_outer_result(converged: bool, rho: Array1<f64>) -> OuterResult {
    let plan = OuterPlan {
        solver: Solver::Bfgs,
        hessian_source: HessianSource::BfgsApprox,
    };
    let mut outcome = OuterResult::new(rho, 1.25, 7, converged, plan);
    outcome.final_grad_norm = Some(0.5);
    outcome
}
/// Builds a 72-row continuous dataset with columns `y`, `ct`, `st`.
///
/// For row `i`, with `t = 2*pi*i/72`: `ct = cos(t)`, `st = sin(t)` and
/// `y = 0.5*sin(t) + 0.15*cos(3t)`.
fn duchon_workflow_dataset() -> Dataset {
    const NAMES: [&str; 3] = ["y", "ct", "st"];

    let n = 72usize;
    let mut values = Array2::<f64>::zeros((n, 3));
    for i in 0..n {
        let t = 2.0 * std::f64::consts::PI * i as f64 / n as f64;
        let row = [0.5 * t.sin() + 0.15 * (3.0 * t).cos(), t.cos(), t.sin()];
        for (j, cell) in row.iter().enumerate() {
            values[[i, j]] = *cell;
        }
    }

    Dataset {
        headers: NAMES.iter().map(|name| name.to_string()).collect(),
        values,
        schema: DataSchema {
            columns: NAMES
                .iter()
                .map(|name| SchemaColumn {
                    name: name.to_string(),
                    kind: ColumnKindTag::Continuous,
                    levels: vec![],
                })
                .collect(),
        },
        column_kinds: NAMES.iter().map(|_| ColumnKindTag::Continuous).collect(),
    }
}
/// With the default `FitConfig`, a standard Duchon fit request must carry no
/// adaptive-regularization options.
#[test]
fn materialize_standard_keeps_adaptive_regularization_off_by_default_for_duchon() {
    let data = duchon_workflow_dataset();
    let config = FitConfig::default();
    let materialized = materialize("y ~ duchon(ct, st, centers=12)", &data, &config)
        .expect("Duchon standard materialization should succeed");
    let request = match materialized.request {
        FitRequest::Standard(request) => request,
        _ => panic!("expected standard request"),
    };
    assert!(request.options.adaptive_regularization.is_none());
}
/// With `adaptive_regularization = Some(true)`, the standard Duchon request
/// must carry adaptive-regularization options whose `enabled` flag is set.
#[test]
fn materialize_standard_honors_adaptive_regularization_enable() {
    let data = duchon_workflow_dataset();
    let config = FitConfig {
        adaptive_regularization: Some(true),
        ..FitConfig::default()
    };
    let materialized = materialize("y ~ duchon(ct, st, centers=12)", &data, &config)
        .expect("Duchon materialization should allow enabling adaptive regularization");
    let request = match materialized.request {
        FitRequest::Standard(request) => request,
        _ => panic!("expected standard request"),
    };
    let adaptive = request
        .options
        .adaptive_regularization
        .expect("Duchon should enable adaptive regularization when requested");
    assert!(adaptive.enabled);
}
/// With `adaptive_regularization = Some(false)`, the standard Duchon request
/// must carry no adaptive-regularization options.
#[test]
fn materialize_standard_honors_adaptive_regularization_disable() {
    let data = duchon_workflow_dataset();
    let config = FitConfig {
        adaptive_regularization: Some(false),
        ..FitConfig::default()
    };
    let materialized = materialize("y ~ duchon(ct, st, centers=12)", &data, &config)
        .expect("Duchon materialization should allow disabling adaptive regularization");
    let request = match materialized.request {
        FitRequest::Standard(request) => request,
        _ => panic!("expected standard request"),
    };
    assert!(request.options.adaptive_regularization.is_none());
}
/// A `duchon(...)` term without `length_scale` must materialize with no length
/// scale, the `Linear` nullspace order and power `0.5`.
#[test]
fn materialize_standard_duchon_defaults_to_pure_scale_free_basis() {
    let data = duchon_workflow_dataset();
    let config = FitConfig::default();
    let materialized = materialize("y ~ duchon(ct, st, centers=12)", &data, &config)
        .expect("Duchon materialization should succeed");
    let request = match materialized.request {
        FitRequest::Standard(request) => request,
        _ => panic!("expected standard request"),
    };
    // The formula has a single smooth term, so it sits at index 0.
    let spec = match &request.spec.smooth_terms[0].basis {
        SmoothBasisSpec::Duchon { spec, .. } => spec,
        _ => panic!("expected Duchon smooth"),
    };
    assert_eq!(spec.length_scale, None);
    assert!(matches!(spec.nullspace_order, DuchonNullspaceOrder::Linear));
    assert_eq!(spec.power, 0.5);
}
/// A `duchon(...)` term with `length_scale=1.0` must materialize with that
/// length scale, the `Linear` nullspace order and power `0.0`.
#[test]
fn materialize_standard_duchon_length_scale_opts_into_hybrid_basis() {
    let data = duchon_workflow_dataset();
    let config = FitConfig::default();
    let materialized = materialize(
        "y ~ duchon(ct, st, centers=12, length_scale=1.0)",
        &data,
        &config,
    )
    .expect("hybrid Duchon materialization should succeed");
    let request = match materialized.request {
        FitRequest::Standard(request) => request,
        _ => panic!("expected standard request"),
    };
    // The formula has a single smooth term, so it sits at index 0.
    let spec = match &request.spec.smooth_terms[0].basis {
        SmoothBasisSpec::Duchon { spec, .. } => spec,
        _ => panic!("expected Duchon smooth"),
    };
    assert_eq!(spec.length_scale, Some(1.0));
    assert_eq!(spec.nullspace_order, DuchonNullspaceOrder::Linear);
    assert_eq!(spec.power, 0.0);
}
/// Puts one `linkwiggle(...)` in the main formula and a different one in the
/// logslope formula, then checks that the main-formula settings land in
/// `spec.link_dev`, the logslope settings land in `spec.score_warp`, and the
/// inference notes mention both blocks.
#[test]
fn workflow_survival_marginal_slope_routes_logslope_linkwiggle_into_score_warp_only() {
    let data = workflow_test_dataset();
    let logslope =
        "1 + linkwiggle(degree=5, internal_knots=7, penalty_order=\"2,3\")".to_string();
    let config = FitConfig {
        survival_likelihood: "marginal-slope".to_string(),
        logslope_formula: Some(logslope),
        z_column: Some("z".to_string()),
        ..FitConfig::default()
    };
    let materialized = materialize(
        "Surv(age_entry, age_exit, event) ~ s(bmi) + linkwiggle(degree=4, internal_knots=9, penalty_order=\"1\")",
        &data,
        &config,
    )
    .expect("workflow materialization should succeed");

    let inference_notes = materialized.inference_notes;
    let request = match materialized.request {
        FitRequest::SurvivalMarginalSlope(request) => request,
        _ => panic!("expected survival marginal-slope request"),
    };

    // Settings taken from the main-formula linkwiggle.
    let link_dev = request.spec.link_dev.expect("main-formula link-dev");
    // Settings taken from the logslope-formula linkwiggle.
    let score_warp = request.spec.score_warp.expect("logslope score-warp");

    assert_eq!(link_dev.degree, 4);
    assert_eq!(link_dev.num_internal_knots, 9);
    assert_eq!(link_dev.penalty_order, 1);
    assert_eq!(link_dev.penalty_orders, vec![1]);

    assert_eq!(score_warp.degree, 5);
    assert_eq!(score_warp.num_internal_knots, 7);
    assert_eq!(score_warp.penalty_order, 3);
    assert_eq!(score_warp.penalty_orders, vec![2, 3]);

    let noted = |needle: &str| inference_notes.iter().any(|note| note.contains(needle));
    assert!(
        noted("link-deviation block"),
        "workflow notes should mention main-formula linkwiggle routing"
    );
    assert!(
        noted("score-warp block"),
        "workflow notes should mention logslope_formula linkwiggle routing"
    );
}
/// With only `logslope_formula` and `z_column` configured (no explicit
/// family), a binary-response formula must materialize as a
/// `FitRequest::BernoulliMarginalSlope`.
#[test]
fn materialize_routes_bernoulli_marginal_slope_when_logslope_and_z_are_set() {
    let data = workflow_test_dataset();
    let config = FitConfig {
        logslope_formula: Some("1".to_string()),
        z_column: Some("z".to_string()),
        ..FitConfig::default()
    };
    let materialized = materialize("event ~ bmi", &data, &config)
        .expect("Bernoulli marginal-slope materialization should succeed");
    let routed_to_bms = matches!(
        materialized.request,
        FitRequest::BernoulliMarginalSlope(_)
    );
    assert!(routed_to_bms);
}
/// Fits a Bernoulli marginal-slope formula that contains a Matern surface, a
/// scalar `x`, and a scalar `constant_spline_col` whose data column is `1.0`
/// in every row.
///
/// Materialization must succeed, keep only `x` among the marginal linear
/// terms, keep one smooth term in each of the marginal and logslope specs, and
/// name the dropped column in the inference notes.
#[test]
fn materialize_bernoulli_marginal_slope_prunes_redundant_scalar_term() {
    // (column name, is-binary) in dataset column order.
    const COLUMNS: [(&str, bool); 7] = [
        ("event", true),
        ("x", false),
        ("constant_spline_col", false),
        ("prs_z", false),
        ("PC1", false),
        ("PC2", false),
        ("PC3", false),
    ];
    // One inner array per observation; flattened in row order below.
    const ROWS: [[f64; 7]; 6] = [
        [0.0, -2.0, 1.0, -1.2, -1.0, 0.2, 0.7],
        [1.0, -1.0, 1.0, -0.4, -0.4, -0.3, 0.5],
        [0.0, 0.0, 1.0, 0.1, 0.1, 0.4, -0.2],
        [1.0, 1.0, 1.0, 0.5, 0.7, -0.6, 0.3],
        [0.0, 2.0, 1.0, 1.1, 1.2, 0.9, 0.0],
        [1.0, 3.0, 1.0, 1.7, 1.6, -0.8, -0.4],
    ];
    let kind_of = |binary: bool| {
        if binary {
            ColumnKindTag::Binary
        } else {
            ColumnKindTag::Continuous
        }
    };

    let data = Dataset {
        headers: COLUMNS.iter().map(|(name, _)| name.to_string()).collect(),
        values: Array2::from_shape_vec((6, 7), ROWS.concat())
            .expect("BMS redundant scalar test data shape"),
        schema: DataSchema {
            columns: COLUMNS
                .iter()
                .map(|&(name, binary)| SchemaColumn {
                    name: name.to_string(),
                    kind: kind_of(binary),
                    levels: vec![],
                })
                .collect(),
        },
        column_kinds: COLUMNS.iter().map(|&(_, binary)| kind_of(binary)).collect(),
    };
    let config = FitConfig {
        logslope_formula: Some("matern(PC1, PC2, PC3, centers=3)".to_string()),
        z_column: Some("prs_z".to_string()),
        ..FitConfig::default()
    };

    let materialized = materialize(
        "event ~ matern(PC1, PC2, PC3, centers=3) + x + constant_spline_col",
        &data,
        &config,
    )
    .expect("BMS materialization should prune the redundant scalar term");
    let inference_notes = materialized.inference_notes;
    let request = match materialized.request {
        FitRequest::BernoulliMarginalSlope(request) => request,
        _ => panic!("expected Bernoulli marginal-slope request"),
    };

    let marginal = &request.spec.marginalspec;
    let kept: Vec<&str> = marginal
        .linear_terms
        .iter()
        .map(|term| term.name.as_str())
        .collect();
    assert_eq!(kept, vec!["x"]);
    assert_eq!(marginal.smooth_terms.len(), 1);
    assert_eq!(request.spec.logslopespec.smooth_terms.len(), 1);

    let reported = inference_notes
        .iter()
        .any(|note| note.contains("constant_spline_col"));
    assert!(
        reported,
        "materialization should report the removed redundant scalar term; notes={inference_notes:?}"
    );
}
/// Fits a Bernoulli marginal-slope formula with a Matern surface plus six
/// scalar terms, where the `current_age_ns_1` data column is `1.0` in every
/// row.
///
/// Materialization must succeed, drop only `current_age_ns_1` from the
/// marginal linear terms (preserving the order of the rest), keep one smooth
/// term per spec, and name the dropped column in the inference notes.
#[test]
fn materialize_bernoulli_marginal_slope_prunes_binary_outcome_style_scalar_alias() {
    // (column name, is-binary) in dataset column order.
    const COLUMNS: [(&str, bool); 11] = [
        ("event", true),
        ("sex", true),
        ("entry_age_z", false),
        ("current_age_ns_1", false),
        ("current_age_ns_2", false),
        ("current_age_ns_3", false),
        ("current_age_ns_4", false),
        ("prs_z", false),
        ("PC1", false),
        ("PC2", false),
        ("PC3", false),
    ];
    // One inner array per observation; flattened in row order below.
    const ROWS: [[f64; 11]; 8] = [
        [0.0, 0.0, -1.4, 1.0, -0.6, 0.36, -0.216, -1.3, -1.0, 0.2, 0.7],
        [1.0, 1.0, -0.9, 1.0, -0.2, 0.04, -0.008, -0.8, -0.5, -0.3, 0.5],
        [0.0, 0.0, -0.5, 1.0, 0.1, 0.01, 0.001, -0.2, 0.1, 0.4, -0.2],
        [1.0, 1.0, -0.1, 1.0, 0.4, 0.16, 0.064, 0.3, 0.7, -0.6, 0.3],
        [0.0, 0.0, 0.3, 1.0, 0.7, 0.49, 0.343, 0.8, 1.2, 0.9, 0.0],
        [1.0, 1.0, 0.7, 1.0, 1.0, 1.0, 1.0, 1.2, 1.6, -0.8, -0.4],
        [0.0, 0.0, 1.1, 1.0, 1.3, 1.69, 2.197, 1.6, -1.4, 0.8, -0.9],
        [1.0, 1.0, 1.5, 1.0, 1.6, 2.56, 4.096, 2.0, 0.3, -1.1, 0.6],
    ];
    let kind_of = |binary: bool| {
        if binary {
            ColumnKindTag::Binary
        } else {
            ColumnKindTag::Continuous
        }
    };

    let data = Dataset {
        headers: COLUMNS.iter().map(|(name, _)| name.to_string()).collect(),
        values: Array2::from_shape_vec((8, 11), ROWS.concat())
            .expect("binary-outcome-style BMS scalar-alias test data shape"),
        schema: DataSchema {
            columns: COLUMNS
                .iter()
                .map(|&(name, binary)| SchemaColumn {
                    name: name.to_string(),
                    kind: kind_of(binary),
                    levels: vec![],
                })
                .collect(),
        },
        column_kinds: COLUMNS.iter().map(|&(_, binary)| kind_of(binary)).collect(),
    };
    let config = FitConfig {
        logslope_formula: Some("matern(PC1, PC2, PC3, centers=3)".to_string()),
        z_column: Some("prs_z".to_string()),
        ..FitConfig::default()
    };

    let materialized = materialize(
        "event ~ matern(PC1, PC2, PC3, centers=3) + sex + entry_age_z + current_age_ns_1 + current_age_ns_2 + current_age_ns_3 + current_age_ns_4",
        &data,
        &config,
    )
    .expect("BMS materialization should prune the local-column-3 scalar alias");
    let MaterializedModel {
        request,
        inference_notes,
    } = materialized;
    let request = match request {
        FitRequest::BernoulliMarginalSlope(request) => request,
        _ => panic!("expected Bernoulli marginal-slope request"),
    };

    let marginal = &request.spec.marginalspec;
    let kept: Vec<&str> = marginal
        .linear_terms
        .iter()
        .map(|term| term.name.as_str())
        .collect();
    assert_eq!(
        kept,
        vec![
            "sex",
            "entry_age_z",
            "current_age_ns_2",
            "current_age_ns_3",
            "current_age_ns_4"
        ]
    );
    assert_eq!(marginal.smooth_terms.len(), 1);
    assert_eq!(request.spec.logslopespec.smooth_terms.len(), 1);

    let reported = inference_notes
        .iter()
        .any(|note| note.contains("current_age_ns_1"));
    assert!(
        reported,
        "materialization should report the removed binary-outcome-style scalar alias; notes={:?}",
        inference_notes
    );
}
/// When the all-ones `constant_spline_col` enters the formula as a constrained
/// `linear(..., min=0.0)` term, materialization must fail with an error saying
/// the constrained linear term is redundant, rather than dropping it.
#[test]
fn materialize_bernoulli_marginal_slope_rejects_constrained_redundant_scalar_term() {
    // (column name, is-binary) in dataset column order.
    const COLUMNS: [(&str, bool); 4] = [
        ("event", true),
        ("x", false),
        ("constant_spline_col", false),
        ("prs_z", false),
    ];
    // One inner array per observation; flattened in row order below.
    const ROWS: [[f64; 4]; 6] = [
        [0.0, -2.0, 1.0, -1.2],
        [1.0, -1.0, 1.0, -0.4],
        [0.0, 0.0, 1.0, 0.1],
        [1.0, 1.0, 1.0, 0.5],
        [0.0, 2.0, 1.0, 1.1],
        [1.0, 3.0, 1.0, 1.7],
    ];
    let kind_of = |binary: bool| {
        if binary {
            ColumnKindTag::Binary
        } else {
            ColumnKindTag::Continuous
        }
    };

    let data = Dataset {
        headers: COLUMNS.iter().map(|(name, _)| name.to_string()).collect(),
        values: Array2::from_shape_vec((6, 4), ROWS.concat())
            .expect("BMS constrained redundant scalar test data shape"),
        schema: DataSchema {
            columns: COLUMNS
                .iter()
                .map(|&(name, binary)| SchemaColumn {
                    name: name.to_string(),
                    kind: kind_of(binary),
                    levels: vec![],
                })
                .collect(),
        },
        column_kinds: COLUMNS.iter().map(|&(_, binary)| kind_of(binary)).collect(),
    };
    let config = FitConfig {
        logslope_formula: Some("1".to_string()),
        z_column: Some("prs_z".to_string()),
        ..FitConfig::default()
    };

    let Err(err) = materialize(
        "event ~ x + linear(constant_spline_col, min=0.0)",
        &data,
        &config,
    ) else {
        panic!("constrained duplicate scalar term must be rejected, not pruned");
    };
    let msg = err.to_string();
    assert!(
        msg.contains("constrained linear term 'constant_spline_col' is redundant"),
        "error should explain that the constrained duplicate scalar cannot be pruned: {msg}"
    );
}
/// Calls `prune_unidentified_linear_terms_for_marginal_slope` directly on a
/// spec whose only linear term has `double_penalty: true` and points at an
/// all-ones column.
///
/// The call must return an error naming the explicitly penalized term as
/// redundant, leave the single linear term in the spec, and add no notes.
#[test]
fn bernoulli_marginal_slope_prune_rejects_penalized_redundant_scalar_term() {
    // One inner array per observation: (event, constant_spline_col).
    const ROWS: [[f64; 2]; 4] = [[0.0, 1.0], [1.0, 1.0], [0.0, 1.0], [1.0, 1.0]];

    let event_column = SchemaColumn {
        name: "event".to_string(),
        kind: ColumnKindTag::Binary,
        levels: vec![],
    };
    let constant_column = SchemaColumn {
        name: "constant_spline_col".to_string(),
        kind: ColumnKindTag::Continuous,
        levels: vec![],
    };
    let data = Dataset {
        headers: vec!["event".to_string(), "constant_spline_col".to_string()],
        values: Array2::from_shape_vec((4, 2), ROWS.concat())
            .expect("BMS penalized redundant scalar test data shape"),
        schema: DataSchema {
            columns: vec![event_column, constant_column],
        },
        column_kinds: vec![ColumnKindTag::Binary, ColumnKindTag::Continuous],
    };

    let penalized_term = LinearTermSpec {
        name: "constant_spline_col".to_string(),
        feature_col: 1,
        feature_cols: vec![1],
        categorical_levels: vec![],
        double_penalty: true,
        coefficient_geometry: crate::smooth::LinearCoefficientGeometry::Unconstrained,
        coefficient_min: None,
        coefficient_max: None,
    };
    let mut spec = TermCollectionSpec {
        linear_terms: vec![penalized_term],
        random_effect_terms: vec![],
        smooth_terms: vec![],
    };
    let mut notes = Vec::new();

    let msg = prune_unidentified_linear_terms_for_marginal_slope(
        &mut spec,
        &data,
        "test BMS formula",
        &mut notes,
    )
    .expect_err("explicitly penalized duplicate scalar term must be rejected")
    .to_string();
    assert!(
        msg.contains("explicitly penalized linear term 'constant_spline_col' is redundant"),
        "error should reject ridge-identification of duplicate scalar directions: {msg}"
    );
    // The rejected call must leave both the spec and the notes untouched.
    assert_eq!(spec.linear_terms.len(), 1);
    assert!(notes.is_empty());
}
/// Uses a z column (`prs_z`) that is `-0.58` in all four rows and checks the
/// exact wording of the resulting materialization error: it must name the
/// column, summarize the constant value, report `weighted_sd` and `n`, explain
/// why the input is invalid, and not fall back to the generic
/// normalization message.
#[test]
fn materialize_bernoulli_marginal_slope_names_constant_z_column() {
    // (column name, is-binary) in dataset column order.
    const COLUMNS: [(&str, bool); 3] = [("event", true), ("bmi", false), ("prs_z", false)];
    // One inner array per observation; flattened in row order below.
    const ROWS: [[f64; 3]; 4] = [
        [0.0, 22.0, -0.58],
        [1.0, 24.0, -0.58],
        [0.0, 27.0, -0.58],
        [1.0, 29.0, -0.58],
    ];
    let kind_of = |binary: bool| {
        if binary {
            ColumnKindTag::Binary
        } else {
            ColumnKindTag::Continuous
        }
    };

    let data = Dataset {
        headers: COLUMNS.iter().map(|(name, _)| name.to_string()).collect(),
        values: Array2::from_shape_vec((4, 3), ROWS.concat()).expect("constant z test data shape"),
        schema: DataSchema {
            columns: COLUMNS
                .iter()
                .map(|&(name, binary)| SchemaColumn {
                    name: name.to_string(),
                    kind: kind_of(binary),
                    levels: vec![],
                })
                .collect(),
        },
        column_kinds: COLUMNS.iter().map(|&(_, binary)| kind_of(binary)).collect(),
    };
    let config = FitConfig {
        logslope_formula: Some("1".to_string()),
        z_column: Some("prs_z".to_string()),
        ..FitConfig::default()
    };

    let Err(err) = materialize("event ~ bmi", &data, &config) else {
        panic!("constant z_column should be rejected before BMS integration");
    };
    let msg = err.to_string();

    assert!(
        msg.contains("z_column 'prs_z' has zero weighted variance"),
        "error should name the constant z_column and diagnose weighted variance: {msg}"
    );
    assert!(
        msg.contains("all 4 values ~= -0.580000"),
        "error should summarize the observed constant value: {msg}"
    );
    let reports_sd = msg.contains("weighted_sd=0.000000e0");
    let reports_n = msg.contains("n=4");
    assert!(
        reports_sd && reports_n,
        "error should report weighted_sd and n: {msg}"
    );
    assert!(
        msg.contains(
            "bernoulli-marginal-slope cannot identify a covariate-varying slope from a constant score"
        ),
        "error should explain why the input is invalid: {msg}"
    );
    let generic = msg.contains("requires z with positive finite weighted standard deviation");
    assert!(
        !generic,
        "workflow should surface the input-style message instead of the generic BMS normalization error: {msg}"
    );
}
/// Checks that three sources of linkwiggle defaults agree on `degree`,
/// `num_internal_knots`, `penalty_orders` and `double_penalty`: parsing a bare
/// `linkwiggle()`, `default_linkwiggle_formulaspec()`, and
/// `DeviationBlockConfig::default()`.
#[test]
fn linkwiggle_defaults_are_consistent_across_formula_and_runtime() {
    let reference = default_linkwiggle_formulaspec();

    // Parsed formula term versus the formula-level default.
    let parsed = parse_linkwiggle_formulaspec(&Default::default(), "linkwiggle()")
        .expect("default linkwiggle should parse");
    assert_eq!(parsed.degree, reference.degree);
    assert_eq!(parsed.num_internal_knots, reference.num_internal_knots);
    assert_eq!(parsed.penalty_orders, reference.penalty_orders);
    assert_eq!(parsed.double_penalty, reference.double_penalty);

    // Runtime block configuration versus the formula-level default.
    let runtime = DeviationBlockConfig::default();
    assert_eq!(runtime.degree, reference.degree);
    assert_eq!(runtime.num_internal_knots, reference.num_internal_knots);
    assert_eq!(runtime.penalty_orders, reference.penalty_orders);
    assert_eq!(runtime.double_penalty, reference.double_penalty);
}
/// For survival marginal-slope, `link(type=probit)` in the main formula must
/// materialize successfully, while `link(type=logit)` must fail with an error
/// containing "only link(type=probit)".
#[test]
fn survival_marginal_slope_accepts_explicit_probit_link() {
    let data = workflow_test_dataset();
    let config = FitConfig {
        survival_likelihood: "marginal-slope".to_string(),
        logslope_formula: Some("1".to_string()),
        z_column: Some("z".to_string()),
        ..FitConfig::default()
    };

    let probit_accepted = materialize(
        "Surv(age_entry, age_exit, event) ~ bmi + link(type=probit)",
        &data,
        &config,
    )
    .is_ok();
    assert!(probit_accepted, "explicit probit should be accepted");

    let Err(err) = materialize(
        "Surv(age_entry, age_exit, event) ~ bmi + link(type=logit)",
        &data,
        &config,
    ) else {
        panic!("non-probit link should be rejected");
    };
    assert!(err.to_string().contains("only link(type=probit)"));
}
/// For dimension 16, `Zero` nullspace order and third argument `2`, the power
/// returned by `minimum_duchon_power_for_operator_penalties` must satisfy
/// `2 * (1 + power) > dim + 2`.
#[test]
fn high_dimensional_duchon_default_power_is_admissible() {
    let dim = 16;
    let power = minimum_duchon_power_for_operator_penalties(dim, DuchonNullspaceOrder::Zero, 2);
    let lhs = 2 * (1 + power);
    let rhs = dim + 2;
    assert!(lhs > rhs);
}
/// Materializes a survival location-scale request with a link wiggle, swaps
/// its inverse link for a SAS link with link optimization turned off, and
/// checks that `fit_survival_location_scale_model` returns an error mentioning
/// "survival link wiggle" and "does not support".
#[test]
fn survival_location_scale_wiggle_rejects_unsupported_inverse_link() {
    let data = workflow_test_dataset();
    let config = FitConfig::default();
    let materialized = materialize(
        "Surv(age_entry, age_exit, event) ~ bmi + linkwiggle(degree=4, internal_knots=3, penalty_order=\"1\")",
        &data,
        &config,
    )
    .expect("workflow materialization should succeed");
    let mut request = match materialized.request {
        FitRequest::SurvivalLocationScale(request) => request,
        _ => panic!("expected survival location-scale request"),
    };

    // Override the materialized inverse link with a SAS state.
    let sas_spec = SasLinkSpec {
        initial_epsilon: 0.1,
        initial_log_delta: 0.0,
    };
    let sas_state = state_from_sasspec(sas_spec).expect("valid SAS state");
    request.spec.inverse_link = InverseLink::Sas(sas_state);
    request.optimize_inverse_link = false;

    let Err(err) = fit_survival_location_scale_model(request) else {
        panic!("survival link wiggle should reject unsupported inverse links");
    };
    assert!(err.contains("survival link wiggle"));
    assert!(err.contains("does not support"));
}
/// A non-converged outer result must make
/// `recover_converged_survival_inverse_link` fail, even though the recovery
/// closure would return a link; the error must mention "did not converge" and
/// "final_objective".
#[test]
fn survival_inverse_link_result_requires_convergence() {
    let rho = Array1::from_vec(vec![0.1, -0.2]);
    let unconverged = workflow_test_outer_result(false, rho);
    let err = recover_converged_survival_inverse_link(
        unconverged,
        "survival inverse-link optimization (SAS, dim=2)",
        |_| Some(InverseLink::Standard(StandardLink::Logit)),
    )
    .expect_err("non-converged inverse-link search should fail");
    assert!(err.contains("did not converge"));
    assert!(err.contains("final_objective"));
}
/// A converged outer result whose recovery closure returns `None` must make
/// `recover_converged_survival_inverse_link` fail; the error must mention the
/// invalid inverse-link state and contain the text "9.0".
#[test]
fn survival_inverse_link_result_requires_recoverable_state() {
    let rho = Array1::from_vec(vec![9.0, 8.0]);
    let converged = workflow_test_outer_result(true, rho);
    let err = recover_converged_survival_inverse_link(
        converged,
        "survival inverse-link optimization (mixture, dim=2)",
        |_| None,
    )
    .expect_err("unrecoverable inverse-link state should fail");
    assert!(err.contains("produced an invalid inverse-link state"));
    assert!(err.contains("9.0"));
}
/// A `timewiggle(...)` term in a formula without a `Surv(...)` response must
/// make `materialize` fail with an error mentioning both `timewiggle(...)` and
/// "survival".
#[test]
fn timewiggle_rejected_in_nonsurvival_main_formula() {
    let data = workflow_test_dataset();
    let config = FitConfig::default();
    let Err(err) = materialize("bmi ~ z + timewiggle(internal_knots=4)", &data, &config) else {
        panic!("timewiggle in a non-survival formula must be rejected, not silently ignored");
    };
    let msg = err.to_string();
    let names_term = msg.contains("timewiggle(...)");
    let names_survival = msg.contains("survival");
    assert!(
        names_term && names_survival,
        "error should explain timewiggle is survival-only, got: {msg}"
    );
}
/// A `survmodel(...)` term in a formula without a `Surv(...)` response must
/// make `materialize` fail with an error mentioning both `survmodel(...)` and
/// "survival".
#[test]
fn survmodel_rejected_in_nonsurvival_main_formula() {
    let data = workflow_test_dataset();
    let config = FitConfig::default();
    let Err(err) = materialize("bmi ~ z + survmodel(spec=net)", &data, &config) else {
        panic!("survmodel in a non-survival formula must be rejected, not silently ignored");
    };
    let msg = err.to_string();
    let names_term = msg.contains("survmodel(...)");
    let names_survival = msg.contains("survival");
    assert!(
        names_term && names_survival,
        "error should explain survmodel is survival-only, got: {msg}"
    );
}
/// A `linkwiggle(...)` term on the continuous `bmi` response must make
/// `materialize` fail with an error mentioning both `linkwiggle(...)` and
/// "binomial".
#[test]
fn linkwiggle_rejected_for_nonbinomial_response() {
    let data = workflow_test_dataset();
    let config = FitConfig::default();
    let Err(err) = materialize("bmi ~ z + linkwiggle(internal_knots=4)", &data, &config) else {
        panic!("linkwiggle on a non-binomial response must be rejected, not silently ignored");
    };
    let msg = err.to_string();
    let names_term = msg.contains("linkwiggle(...)");
    let names_family = msg.contains("binomial");
    assert!(
        names_term && names_family,
        "error should explain linkwiggle is binomial-only, got: {msg}"
    );
}
/// Configuring `link = "flexible(log)"` with the Poisson family must make
/// `materialize` fail with an error mentioning `flexible(...)` and
/// "non-binomial".
#[test]
fn flexible_link_rejected_for_nonbinomial_standard_response() {
    let data = workflow_test_dataset();
    let config = FitConfig {
        family: Some("poisson".to_string()),
        link: Some("flexible(log)".to_string()),
        ..FitConfig::default()
    };
    let Err(err) = materialize("bmi ~ z", &data, &config) else {
        panic!("flexible(log) on a Poisson response must be rejected, not silently ignored");
    };
    let msg = err.to_string();
    let names_link = msg.contains("flexible(...)");
    let names_family = msg.contains("non-binomial");
    assert!(
        names_link && names_family,
        "error should explain flexible links are binomial-only, got: {msg}"
    );
}
/// Requesting a flexible link through the formula
/// (`link(type=flexible(log))`) with the Poisson family must make
/// `materialize` fail with an error mentioning `flexible(...)` and
/// "non-binomial".
#[test]
fn formula_flexible_link_rejected_for_nonbinomial_standard_response() {
    let data = workflow_test_dataset();
    let config = FitConfig {
        family: Some("poisson".to_string()),
        ..FitConfig::default()
    };
    let Err(err) = materialize("bmi ~ z + link(type=flexible(log))", &data, &config) else {
        panic!("formula flexible(log) on a Poisson response must be rejected");
    };
    let msg = err.to_string();
    let names_link = msg.contains("flexible(...)");
    let names_family = msg.contains("non-binomial");
    assert!(
        names_link && names_family,
        "error should explain flexible links are binomial-only, got: {msg}"
    );
}
/// Setting the `flexible_link` flag with the Gaussian family must make
/// `materialize` fail with an error mentioning `flexible(...)` and
/// "non-binomial".
#[test]
fn flexible_link_flag_rejected_for_nonbinomial_standard_response() {
    let data = workflow_test_dataset();
    let config = FitConfig {
        family: Some("gaussian".to_string()),
        flexible_link: true,
        ..FitConfig::default()
    };
    let Err(err) = materialize("bmi ~ z", &data, &config) else {
        panic!("flexible_link=True on a Gaussian response must be rejected");
    };
    let msg = err.to_string();
    let names_link = msg.contains("flexible(...)");
    let names_family = msg.contains("non-binomial");
    assert!(
        names_link && names_family,
        "error should explain flexible links are binomial-only, got: {msg}"
    );
}
/// With a noise formula present and `link = "flexible(identity)"`, `materialize` must
/// reject the request (described by the expectation message as a Gaussian
/// location-scale response) with the same `flexible(...)` / "non-binomial" wording.
#[test]
fn flexible_link_rejected_for_nonbinomial_location_scale_response() {
    let dataset = workflow_test_dataset();
    let location_scale_flexible = FitConfig {
        link: Some(String::from("flexible(identity)")),
        noise_formula: Some(String::from("1")),
        ..FitConfig::default()
    };
    let Err(rejection) = materialize("bmi ~ z", &dataset, &location_scale_flexible) else {
        panic!("flexible(identity) on a Gaussian location-scale response must be rejected");
    };
    let msg = rejection.to_string();
    let explains_restriction = ["flexible(...)", "non-binomial"]
        .iter()
        .all(|needle| msg.contains(*needle));
    assert!(
        explains_restriction,
        "error should explain flexible links are binomial-only, got: {msg}"
    );
}
/// `timewiggle(...)` inside a `Surv(...)` formula must not trip the non-survival guard.
///
/// Materialization is allowed to fail for other reasons; the test only forbids an error
/// whose text contains both `timewiggle(...)` and "meaningless".
#[test]
fn timewiggle_still_accepted_in_survival_formula() {
    let survival_data = load_survival_dataset();
    let outcome = materialize(
        "Surv(entry, exit, event) ~ x + timewiggle(internal_knots=2)",
        &survival_data,
        &FitConfig::default(),
    );
    // Success needs no further checks.
    let Err(failure) = outcome else {
        return;
    };
    let msg = failure.to_string();
    let hit_non_survival_guard = msg.contains("timewiggle(...)") && msg.contains("meaningless");
    assert!(
        !hit_non_survival_guard,
        "survival timewiggle wrongly rejected by the non-survival guard: {msg}"
    );
}
/// Builds a deterministic 48-row dataset with columns `y` and `x`.
///
/// `x` is an evenly spaced grid on `[-2, 2]` and `y = 0.7 x + 0.3 sin(1.3 x)`; both are
/// written as `{:.17e}` strings and encoded with an inferred schema.
fn gaussian_location_scale_dataset() -> Dataset {
    let row_count = 48usize;
    let last_index = (row_count - 1) as f64;
    let rows: Vec<csv::StringRecord> = (0..row_count)
        .map(|row| {
            let x = -2.0 + 4.0 * (row as f64) / last_index;
            let y = 0.7 * x + 0.3 * (1.3 * x).sin();
            // Column order is (y, x), matching the header list below.
            csv::StringRecord::from(vec![format!("{y:.17e}"), format!("{x:.17e}")])
        })
        .collect();
    let headers = vec!["y".to_string(), "x".to_string()];
    crate::inference::data::encode_recordswith_inferred_schema(headers, rows)
        .expect("encode gaussian location-scale dataset")
}
/// Builds a deterministic 60-row dataset with columns `y` and `x`.
///
/// `x` is an evenly spaced grid on `[-2, 2]`; `y` alternates 1, 0, 1, 0, ... starting
/// with 1 on the first row. Values are written as `{:.17e}` strings and encoded with an
/// inferred schema.
fn binomial_location_scale_dataset() -> Dataset {
    let row_count = 60usize;
    let last_index = (row_count - 1) as f64;
    let rows: Vec<csv::StringRecord> = (0..row_count)
        .map(|row| {
            let x = -2.0 + 4.0 * (row as f64) / last_index;
            let y: f64 = match row % 2 {
                0 => 1.0,
                _ => 0.0,
            };
            // Column order is (y, x), matching the header list below.
            csv::StringRecord::from(vec![format!("{y:.17e}"), format!("{x:.17e}")])
        })
        .collect();
    let headers = vec!["y".to_string(), "x".to_string()];
    crate::inference::data::encode_recordswith_inferred_schema(headers, rows)
        .expect("encode binomial location-scale dataset")
}
/// Small link-wiggle configuration shared by the engine/reference parity tests:
/// degree 3, three internal knots, a single penalty of order 2, no double penalty.
fn small_wiggle_cfg() -> LinkWiggleConfig {
    let degree = 3;
    let num_internal_knots = 3;
    let penalty_orders = vec![2];
    LinkWiggleConfig {
        degree,
        num_internal_knots,
        penalty_orders,
        double_penalty: false,
    }
}
/// Asserts that two fit results hold the same number of coefficient blocks and that each
/// block's `beta` agrees entry by entry within `1e-12 * (1 + |reference|)`.
///
/// `lhs` is reported as the "engine" side and `rhs` as the "reference" side; `label`
/// prefixes every failure message. Panics on the first mismatch.
fn assert_block_states_match(label: &str, lhs: &UnifiedFitResult, rhs: &UnifiedFitResult) {
    let engine_blocks = &lhs.block_states;
    let reference_blocks = &rhs.block_states;
    assert_eq!(
        engine_blocks.len(),
        reference_blocks.len(),
        "{label}: block count mismatch (engine {} vs reference {})",
        engine_blocks.len(),
        reference_blocks.len()
    );
    let paired_blocks = engine_blocks.iter().zip(reference_blocks.iter());
    for (i, (engine_block, reference_block)) in paired_blocks.enumerate() {
        let engine_beta = &engine_block.beta;
        let reference_beta = &reference_block.beta;
        // Lengths are checked first so the zipped comparison below cannot silently truncate.
        assert_eq!(
            engine_beta.len(),
            reference_beta.len(),
            "{label}: block {i} coefficient length mismatch"
        );
        for (j, (&av, &bv)) in engine_beta.iter().zip(reference_beta.iter()).enumerate() {
            let tolerance = 1e-12 * (1.0 + bv.abs());
            let gap = (av - bv).abs();
            assert!(
                gap <= tolerance,
                "{label}: block {i} coef {j} diverged: engine {av:.17e} vs reference {bv:.17e}"
            );
        }
    }
}
/// Asserts that the engine's and the reference's optional link-wiggle coefficients agree.
///
/// Both absent is a match. Exactly one present panics with a presence-mismatch message.
/// When both are present, the lengths must be equal and every pair must satisfy
/// `|engine - reference| <= 1e-12 * (1 + |reference|)`. `label` prefixes every message.
fn assert_beta_link_wiggle_match(
    label: &str,
    engine: &Option<Vec<f64>>,
    reference: &Option<Vec<f64>>,
) {
    let (engine_coefs, reference_coefs) = match (engine.as_deref(), reference.as_deref()) {
        (None, None) => return,
        (Some(e), Some(r)) => (e, r),
        (e, r) => panic!(
            "{label}: beta_link_wiggle presence mismatch (engine is_some={}, reference is_some={})",
            e.is_some(),
            r.is_some()
        ),
    };
    assert_eq!(
        engine_coefs.len(),
        reference_coefs.len(),
        "{label}: beta_link_wiggle length mismatch (engine {} vs reference {})",
        engine_coefs.len(),
        reference_coefs.len()
    );
    for (j, (&ev, &rv)) in engine_coefs.iter().zip(reference_coefs.iter()).enumerate() {
        let tolerance = 1e-12 * (1.0 + rv.abs());
        let gap = (ev - rv).abs();
        assert!(
            gap <= tolerance,
            "{label}: beta_link_wiggle coef {j} diverged: engine {ev:.17e} vs reference {rv:.17e}"
        );
    }
}
/// Checks that `fit_gaussian_location_scale_model` reproduces the step-by-step reference
/// flow on the synthetic Gaussian dataset.
///
/// Without a wiggle, the engine is compared against `fit_gaussian_location_scale_terms`.
/// With a wiggle, it is compared against pilot fit -> wiggle basis selection -> refit.
/// Coefficients and knots must agree to a relative `1e-12`.
#[test]
fn gaussian_location_scale_engine_matches_reference_flow() {
    let data = gaussian_location_scale_dataset();
    let config = FitConfig {
        family: Some(String::from("gaussian")),
        noise_formula: Some(String::from("1")),
        ..FitConfig::default()
    };
    let materialized =
        materialize("y ~ x", &data, &config).expect("gaussian location-scale materialization");
    let FitRequest::GaussianLocationScale(request) = materialized.request else {
        panic!("expected a Gaussian location-scale request");
    };
    let GaussianLocationScaleFitRequest {
        data: req_data,
        spec,
        options,
        kappa_options,
        ..
    } = request;

    // Engine entry point; only the optional wiggle configuration varies between calls.
    let run_engine = |wiggle| {
        fit_gaussian_location_scale_model(GaussianLocationScaleFitRequest {
            data: req_data,
            spec: spec.clone(),
            wiggle,
            options: options.clone(),
            kappa_options: kappa_options.clone(),
        })
    };
    // Reference terms-only fit, used both as the plain baseline and as the wiggle pilot.
    let run_reference_terms =
        || fit_gaussian_location_scale_terms(req_data, spec.clone(), &options, &kappa_options);

    // --- no-wiggle parity ---
    let engine_plain = run_engine(None).expect("engine gaussian no-wiggle fit");
    let reference_plain = run_reference_terms().expect("reference gaussian no-wiggle fit");
    assert_block_states_match(
        "gaussian/no-wiggle",
        &engine_plain.fit.fit,
        &reference_plain.fit,
    );
    assert!(engine_plain.wiggle_knots.is_none());
    assert!(engine_plain.wiggle_degree.is_none());
    assert!(engine_plain.beta_link_wiggle.is_none());

    // --- wiggle parity ---
    let wiggle_cfg = small_wiggle_cfg();
    let engine_wiggle =
        run_engine(Some(wiggle_cfg.clone())).expect("engine gaussian wiggle fit");
    let ref_pilot = run_reference_terms().expect("reference gaussian pilot");
    let block_cfg = WiggleBlockConfig {
        degree: wiggle_cfg.degree,
        num_internal_knots: wiggle_cfg.num_internal_knots,
        penalty_order: 2,
        double_penalty: wiggle_cfg.double_penalty,
    };
    let ref_basis = select_gaussian_location_scale_link_wiggle_basis_from_pilot(
        &ref_pilot,
        &block_cfg,
        &wiggle_cfg.penalty_orders,
    )
    .expect("reference gaussian wiggle basis selection");
    let ref_solved = fit_gaussian_location_scale_terms_with_selected_wiggle(
        req_data,
        spec.clone(),
        ref_basis,
        &options,
        &kappa_options,
    )
    .expect("reference gaussian wiggle refit");
    assert_block_states_match(
        "gaussian/wiggle",
        &engine_wiggle.fit.fit,
        &ref_solved.fit.fit,
    );
    assert_eq!(
        engine_wiggle.wiggle_degree,
        Some(ref_solved.wiggle_degree),
        "gaussian wiggle degree must match the reference refit"
    );

    let engine_knots = engine_wiggle
        .wiggle_knots
        .as_ref()
        .expect("engine gaussian wiggle knots present");
    let reference_knots = &ref_solved.wiggle_knots;
    assert_eq!(
        engine_knots.len(),
        reference_knots.len(),
        "gaussian wiggle knot count must match the reference refit"
    );
    for (k, (&ek, &rk)) in engine_knots.iter().zip(reference_knots.iter()).enumerate() {
        let tolerance = 1e-12 * (1.0 + rk.abs());
        assert!(
            (ek - rk).abs() <= tolerance,
            "gaussian wiggle knot {k} diverged: engine {ek:.17e} vs reference {rk:.17e}"
        );
    }

    // The reference wiggle coefficients are read from block index 2 of the refit.
    let reference_wiggle_block = ref_solved.fit.fit.block_states.get(2);
    let ref_beta_link_wiggle = reference_wiggle_block.map(|block| block.beta.to_vec());
    assert_beta_link_wiggle_match(
        "gaussian",
        &engine_wiggle.beta_link_wiggle,
        &ref_beta_link_wiggle,
    );
    assert!(
        engine_wiggle.beta_link_wiggle.is_some(),
        "a wiggle refit must populate beta_link_wiggle (block 2 present)"
    );
}
/// Checks that `fit_binomial_location_scale_model` reproduces the step-by-step reference
/// flow on the synthetic binomial dataset.
///
/// Without a wiggle, the engine is compared against `fit_binomial_location_scale_terms`.
/// With a wiggle, the reference first confirms the link supports a joint wiggle, then
/// runs pilot fit -> wiggle basis selection -> refit. Coefficients and knots must agree
/// to a relative `1e-12`.
#[test]
fn binomial_location_scale_engine_matches_reference_flow() {
    let data = binomial_location_scale_dataset();
    let config = FitConfig {
        family: Some(String::from("binomial")),
        noise_formula: Some(String::from("1")),
        ..FitConfig::default()
    };
    let materialized =
        materialize("y ~ x", &data, &config).expect("binomial location-scale materialization");
    let FitRequest::BinomialLocationScale(request) = materialized.request else {
        panic!("expected a binomial location-scale request");
    };
    let BinomialLocationScaleFitRequest {
        data: req_data,
        spec,
        options,
        kappa_options,
        ..
    } = request;

    // Engine entry point; only the optional wiggle configuration varies between calls.
    let run_engine = |wiggle| {
        fit_binomial_location_scale_model(BinomialLocationScaleFitRequest {
            data: req_data,
            spec: spec.clone(),
            wiggle,
            options: options.clone(),
            kappa_options: kappa_options.clone(),
        })
    };
    // Reference terms-only fit, used both as the plain baseline and as the wiggle pilot.
    let run_reference_terms =
        || fit_binomial_location_scale_terms(req_data, spec.clone(), &options, &kappa_options);

    // --- no-wiggle parity ---
    let engine_plain = run_engine(None).expect("engine binomial no-wiggle fit");
    let reference_plain = run_reference_terms().expect("reference binomial no-wiggle fit");
    assert_block_states_match(
        "binomial/no-wiggle",
        &engine_plain.fit.fit,
        &reference_plain.fit,
    );
    assert!(engine_plain.wiggle_knots.is_none());
    assert!(engine_plain.wiggle_degree.is_none());
    assert!(engine_plain.beta_link_wiggle.is_none());

    // --- wiggle parity ---
    let wiggle_cfg = small_wiggle_cfg();
    let engine_wiggle =
        run_engine(Some(wiggle_cfg.clone())).expect("engine binomial wiggle fit");
    // The reference flow verifies link compatibility before fitting the pilot.
    require_inverse_link_supports_joint_wiggle(
        &spec.link_kind,
        "binomial location-scale link wiggle",
    )
    .expect("logit link supports joint wiggle");
    let ref_pilot = run_reference_terms().expect("reference binomial pilot");
    let block_cfg = WiggleBlockConfig {
        degree: wiggle_cfg.degree,
        num_internal_knots: wiggle_cfg.num_internal_knots,
        penalty_order: 2,
        double_penalty: wiggle_cfg.double_penalty,
    };
    let ref_basis = select_binomial_location_scale_link_wiggle_basis_from_pilot(
        &ref_pilot,
        &block_cfg,
        &wiggle_cfg.penalty_orders,
    )
    .expect("reference binomial wiggle basis selection");
    let ref_solved = fit_binomial_location_scale_terms_with_selected_wiggle(
        req_data,
        spec.clone(),
        ref_basis,
        &options,
        &kappa_options,
    )
    .expect("reference binomial wiggle refit");
    assert_block_states_match(
        "binomial/wiggle",
        &engine_wiggle.fit.fit,
        &ref_solved.fit.fit,
    );
    assert_eq!(
        engine_wiggle.wiggle_degree,
        Some(ref_solved.wiggle_degree),
        "binomial wiggle degree must match the reference refit"
    );

    let engine_knots = engine_wiggle
        .wiggle_knots
        .as_ref()
        .expect("engine binomial wiggle knots present");
    let reference_knots = &ref_solved.wiggle_knots;
    assert_eq!(
        engine_knots.len(),
        reference_knots.len(),
        "binomial wiggle knot count must match the reference refit"
    );
    for (k, (&ek, &rk)) in engine_knots.iter().zip(reference_knots.iter()).enumerate() {
        let tolerance = 1e-12 * (1.0 + rk.abs());
        assert!(
            (ek - rk).abs() <= tolerance,
            "binomial wiggle knot {k} diverged: engine {ek:.17e} vs reference {rk:.17e}"
        );
    }

    // The reference wiggle coefficients are read from block index 2 of the refit.
    let reference_wiggle_block = ref_solved.fit.fit.block_states.get(2);
    let ref_beta_link_wiggle = reference_wiggle_block.map(|block| block.beta.to_vec());
    assert_beta_link_wiggle_match(
        "binomial",
        &engine_wiggle.beta_link_wiggle,
        &ref_beta_link_wiggle,
    );
    assert!(
        engine_wiggle.beta_link_wiggle.is_some(),
        "a wiggle refit must populate beta_link_wiggle (block 2 present)"
    );
}
/// `resolve_family` must accept mgcv-style `family(link)` strings, in the mixed-case and
/// padded spellings exercised below, as well as the `_` and `-` separated forms.
///
/// Covers binomial with logit, probit and cloglog links, plus `negative_binomial(log)`.
#[test]
fn resolve_family_accepts_mgcv_parenthesized_family_link_syntax() {
    use crate::solver::fit_orchestration::resolve_family;
    use crate::types::{
        InverseLink, LinkFunction, ResponseColumnKind, ResponseFamily, StandardLink,
    };

    let y = ndarray::array![0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
    // Every binomial case shares the same 0/1 response and differs only in the family string.
    let resolve_binary = |raw: &str| {
        resolve_family(
            Some(raw),
            None,
            None,
            y.view(),
            ResponseColumnKind::Numeric,
            "y",
        )
    };

    let logit_spellings = [
        "binomial(logit)",
        "Binomial(Logit)",
        "binomial(LOGIT)",
        "binomial( logit )",
        "binomial_logit",
        "binomial-logit",
    ];
    for raw in logit_spellings {
        let spec = match resolve_binary(raw) {
            Ok(resolved) => resolved,
            Err(err) => panic!("resolve_family({raw:?}) failed: {err}"),
        };
        assert!(
            matches!(spec.response, ResponseFamily::Binomial),
            "{raw}: expected Binomial response"
        );
        assert_eq!(
            spec.link.link_function(),
            LinkFunction::Logit,
            "{raw}: expected logit link"
        );
    }

    let alternate_links = [
        (
            "binomial(probit)",
            "binomial(probit) resolves",
            LinkFunction::Probit,
        ),
        (
            "Binomial(CLogLog)",
            "binomial(cloglog) resolves",
            LinkFunction::CLogLog,
        ),
    ];
    for (raw, failure_note, expected_link) in alternate_links {
        let resolved = resolve_binary(raw).expect(failure_note);
        assert_eq!(resolved.link.link_function(), expected_link);
    }

    // The negative-binomial case uses its own 0..=3 response vector.
    let counts = ndarray::array![0.0, 1.0, 2.0, 3.0];
    let nb = resolve_family(
        Some("negative_binomial(log)"),
        None,
        None,
        counts.view(),
        ResponseColumnKind::Numeric,
        "y",
    )
    .expect("negative_binomial(log) resolves");
    assert!(matches!(
        nb.response,
        ResponseFamily::NegativeBinomial { .. }
    ));
    assert!(matches!(nb.link, InverseLink::Standard(StandardLink::Log)));
}
/// Builds a 60-row, two-column (`x`, `y`) continuous dataset for the monotone-smooth tests.
///
/// `x = (i + 0.5) / 60` and `y = sqrt(x) + 0.01 * ((7 i) mod 5) / 5`, i.e. an increasing
/// curve plus a small deterministic offset. Rows are stored as `[x, y]`.
fn monotone_parity_dataset() -> Dataset {
    let n = 60usize;
    let flat: Vec<f64> = (0..n)
        .flat_map(|i| {
            let x = (i as f64 + 0.5) / n as f64;
            let offset_step = ((7 * i) % 5) as f64;
            let y = x.sqrt() + 0.01 * offset_step / 5.0;
            [x, y]
        })
        .collect();
    let values = Array2::from_shape_vec((n, 2), flat).expect("monotone parity data shape");
    let continuous_column = |name: &str| SchemaColumn {
        name: name.to_string(),
        kind: ColumnKindTag::Continuous,
        levels: vec![],
    };
    Dataset {
        headers: ["x", "y"].into_iter().map(String::from).collect(),
        values,
        schema: DataSchema {
            columns: vec![continuous_column("x"), continuous_column("y")],
        },
        column_kinds: vec![ColumnKindTag::Continuous, ColumnKindTag::Continuous],
    }
}
/// Regression test for #1196: the options attached to a formula-materialized standard
/// request must have the same `{:#?}` rendering as those built by
/// `canonical_standard_fit_options` from the same config. Two canonical values are also
/// pinned: posterior inference over rho is skipped, and `tol == 1e-10`.
#[test]
fn issue_1196_cli_and_formula_standard_fit_options_match() {
    use crate::solver::fit_orchestration::{
        StandardFitOptionsInputs, canonical_standard_fit_options,
    };

    let data = monotone_parity_dataset();
    let config = FitConfig::default();
    let materialized = materialize("y ~ s(x, shape=monotone_increasing)", &data, &config)
        .expect("formula path materializes the monotone fit");
    let request = match materialized.request {
        FitRequest::Standard(request) => request,
        _ => panic!("expected a standard request for a Gaussian shape-constrained smooth"),
    };

    let cli_inputs = StandardFitOptionsInputs {
        firth_bias_reduction: config.firth,
        ..Default::default()
    };
    let cli_options = canonical_standard_fit_options(&config, cli_inputs);

    // The two option values are compared through their pretty-printed Debug output.
    let formula_rendering = format!("{:#?}", request.options);
    let cli_rendering = format!("{cli_options:#?}");
    assert_eq!(
        formula_rendering, cli_rendering,
        "CLI and formula entry points must build identical standard FitOptions (#1196)"
    );
    assert!(
        request.options.skip_rho_posterior_inference,
        "canonical formula/CLI policy skips the live-rho posterior path"
    );
    assert_eq!(
        request.options.tol, 1e-10,
        "canonical outer-REML tolerance is the gam#893 value, not the stale CLI 1e-6"
    );
}
/// Regression test for #1191: a monotone-increasing shape-constrained smooth must
/// materialize, fit through `fit_model`, yield a standard result, and produce a mean
/// block whose coefficients are all finite.
#[test]
fn issue_1191_shape_constrained_monotone_fits_through_shared_driver() {
    let data = monotone_parity_dataset();
    let config = FitConfig::default();
    let materialized = materialize("y ~ s(x, shape=monotone_increasing)", &data, &config)
        .expect("monotone shape-constrained smooth materializes");
    let fitted = fit_model(materialized.request)
        .expect("monotone shape-constrained smooth fits through the shared driver (#1191)");
    let standard = match fitted {
        FitResult::Standard(standard) => standard,
        _ => panic!("expected a standard fit result"),
    };
    let mean_block = standard
        .fit
        .block_by_role(crate::estimate::BlockRole::Mean)
        .expect("fitted mean block");
    let beta = mean_block.beta.clone();
    let has_non_finite = beta.iter().any(|coef| !coef.is_finite());
    assert!(
        !has_non_finite,
        "fitted coefficients must be finite (no ALO-NaN seed rejection)"
    );
}