use super::*;
pub(crate) fn run_diagnose(args: DiagnoseArgs) -> Result<(), String> {
let mut progress = gam::visualizer::VisualizerSession::new(true);
progress.start_workflow("Diagnose", 5);
progress.set_stage("diagnose", "loading fitted model");
let model = SavedModel::load_from_path(&args.model)?;
progress.advance_workflow(1);
let parsed = parse_formula(&model.formula)?;
if model.predict_model_class() != PredictModelClass::Standard {
return Err(format!(
"diagnose --alo is not yet supported for {model_class:?} models; \
only standard GAM fits are covered. \
(You can still inspect the model with `gam report <model>`.)",
model_class = model.predict_model_class()
));
}
if model.spline_scan.is_some() {
return Err(
"diagnose --alo needs the dense leave-one-out leverage, which a \
spline-scan model does not retain (it stores only the per-knot \
posterior of the exact O(n) smoother). Use `gam report <model>` \
for its fitted quantities, or refit with double_penalty=true to \
obtain the dense fit ALO requires."
.to_string(),
);
}
progress.set_stage("diagnose", "loading diagnostic dataset");
let ds = load_datasetwith_model_schema_for_diagnostics(&args.data, &model)?;
require_dataset_rows("diagnose", &args.data, ds.values.nrows())?;
progress.advance_workflow(2);
let col_map = ds.column_map();
let training_headers = model.training_headers.as_ref();
let family = model.likelihood();
let y_col = resolve_role_col(&col_map, &parsed.response, "response")?;
let y = ds.values.column(y_col).to_owned();
let spec = resolve_termspec_for_prediction(
&model.resolved_termspec,
training_headers,
&col_map,
"resolved_termspec",
)?;
progress.set_stage("diagnose", "building diagnostic design");
let design = build_term_collection_design(ds.values.view(), &spec)
.map_err(|e| format!("failed to build term collection design: {e}"))?;
progress.advance_workflow(3);
let link = family.link_function();
let weights = Array1::ones(ds.values.nrows());
let (offset, _noise_offset) = report_offset_for(&model, &ds, &col_map)?;
let alo = if let Some((unified, geom)) = model
.unified()
.and_then(|u| u.geometry.as_ref().map(|g| (u, g)))
{
progress.set_stage("diagnose", "computing alo from saved geometry");
let fit_saved = fit_result_from_saved_model_for_prediction(&model)?;
let eta = &design.design.dot(&fit_saved.beta) + &offset;
let alo_design_dense = design.design.to_dense();
let phi = geometry_alo_phi(unified, link);
let input =
gam::alo::AloInput::from_geometry(geom, &alo_design_dense, &eta, &offset, link, phi);
progress.advance_workflow(4);
gam::alo::compute_alo_from_input(&input)
.map_err(|e| format!("compute_alo_from_input (geometry path) failed: {e}"))?
} else {
progress.set_stage("diagnose", "refitting model for alo");
let fit_options = FitOptions {
latent_cloglog: None,
mixture_link: None,
optimize_mixture: false,
sas_link: None,
optimize_sas: false,
compute_inference: false,
skip_rho_posterior_inference: false,
max_iter: 80,
tol: 1e-6,
nullspace_dims: design.nullspace_dims.clone(),
linear_constraints: design.linear_constraints.clone(),
firth_bias_reduction: false,
adaptive_regularization: None,
penalty_shrinkage_floor: Some(1e-6),
rho_prior: Default::default(),
kronecker_penalty_system: None,
kronecker_factored: None,
persist_warm_start_disk: false,
};
let alo_result = match alo_refit_route_for_termspec(&spec) {
AloRefitRoute::UnifiedTermCollection => {
let fitted = fit_term_collection_forspec(
ds.values.view(),
y.view(),
weights.view(),
offset.view(),
&spec,
family,
&fit_options,
)
.map_err(|e| {
format!("fit_term_collection_forspec failed during diagnose refit: {e}")
})?;
let eta = &fitted.design.design.dot(&fitted.fit.beta) + &offset;
let dense_alo_design = fitted.design.design.to_dense();
let phi = geometry_alo_phi(&fitted.fit, link);
gam::alo::compute_alo_diagnostics_from_unified(
&fitted.fit,
&dense_alo_design,
&eta,
&offset,
link,
phi,
)
.map_err(|e| {
format!(
"compute_alo_diagnostics_from_unified failed during diagnose refit: {e}"
)
})
}
AloRefitRoute::StandardGam => {
let fit = fit_gam(
design.design.clone(),
y.view(),
weights.view(),
offset.view(),
&design.penalties,
family,
&fit_options,
)
.map_err(|e| format!("fit_gam failed during diagnose refit: {e}"))?;
compute_alo_diagnostics_from_fit(&fit, y.view(), link)
.map_err(|e| format!("compute_alo_diagnostics_from_fit failed: {e}"))
}
};
progress.advance_workflow(4);
alo_result?
};
let mut rows: Vec<(usize, f64, f64, f64)> = (0..alo.leverage.len())
.map(|i| (i, alo.leverage[i], alo.eta_tilde[i], alo.se_sandwich[i]))
.collect();
rows.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let mut table = Table::new();
table
.load_preset(UTF8_FULL)
.set_content_arrangement(ContentArrangement::Dynamic)
.set_header(vec!["row", "leverage", "eta_tilde", "alo_se"]);
for (row, lev, eta, se) in rows.into_iter().take(12) {
table.add_row(Row::from(vec![
Cell::new(row),
Cell::new(format!("{lev:.4}")),
Cell::new(format!("{eta:.6}")),
Cell::new(format!("{se:.6}")),
]));
}
cli_out!("ALO diagnostics (top leverage rows):");
cli_out!("{table}");
if let Some(unified) = model.unified() {
let fit_saved = fit_result_from_saved_model_for_prediction(&model)?;
let eta_hat = &design.design.dot(&fit_saved.beta) + &offset;
let comparison = gam::model_comparison::model_comparison_from_unified(
unified,
y.view(),
eta_hat.view(),
weights.view(),
Some(&alo),
);
let mut summary = Table::new();
summary
.load_preset(UTF8_FULL)
.set_content_arrangement(ContentArrangement::Dynamic)
.set_header(vec!["criterion", "value"]);
summary.add_row(Row::from(vec![
Cell::new("edf (conditional)"),
Cell::new(format!("{:.4}", comparison.edf.conditional)),
]));
summary.add_row(Row::from(vec![
Cell::new("edf (corrected, WPS)"),
Cell::new(format!("{:.4}", comparison.edf.corrected)),
]));
summary.add_row(Row::from(vec![
Cell::new("rho-uncertainty df"),
Cell::new(format!("{:.4}", comparison.edf.rho_uncertainty_df())),
]));
summary.add_row(Row::from(vec![
Cell::new("AIC (conditional)"),
Cell::new(format!("{:.4}", comparison.aic_conditional)),
]));
summary.add_row(Row::from(vec![
Cell::new("AIC (corrected)"),
Cell::new(format!("{:.4}", comparison.aic_corrected)),
]));
if let Some(loo) = comparison.loo.as_ref() {
summary.add_row(Row::from(vec![
Cell::new("PSIS-LOO elpd"),
Cell::new(format!("{:.4} (se {:.4})", loo.elpd, loo.se)),
]));
summary.add_row(Row::from(vec![
Cell::new("PSIS k_hat (max)"),
Cell::new(format!("{:.3} ({} unreliable)", loo.k_hat_max, loo.n_k_bad)),
]));
}
cli_out!("Model comparison (corrected AIC + PSIS-LOO):");
cli_out!("{summary}");
}
progress.advance_workflow(5);
progress.finish_progress("diagnostics complete");
Ok(())
}