use gam::estimate::{FitOptions, fit_gam};
use gam::pirls::update_glmvectors_by_family;
use gam::predict::predict_gam;
use gam::probability::normal_cdf;
use gam::smooth::BlockwisePenalty;
use gam::types::{GlmLikelihoodFamily, GlmLikelihoodSpec, LikelihoodFamily};
use ndarray::{Array1, Array2};
use rand::rngs::StdRng;
use rand::{RngExt, SeedableRng};
#[test]
fn probit_fit_and_predict_fast_integration() {
    // End-to-end smoke test: simulate Bernoulli/probit data on a 1-D linear
    // predictor, fit a lightly penalized GAM, and sanity-check both the
    // fitted object and the resulting predictions.
    let n = 400usize;
    let mut rng = StdRng::seed_from_u64(7);
    let mut design = Array2::<f64>::zeros((n, 2));
    let mut response = Array1::<f64>::zeros(n);
    for row in 0..n {
        // Evenly spaced covariate on [-2, 2]; true model: eta = -0.3 + 1.1 x.
        let covariate = -2.0 + 4.0 * (row as f64) / (n as f64 - 1.0);
        let success_prob = normal_cdf(-0.3 + 1.1 * covariate);
        design[[row, 0]] = 1.0;
        design[[row, 1]] = covariate;
        response[row] = if rng.random::<f64>() < success_prob {
            1.0
        } else {
            0.0
        };
    }
    let prior_weights = Array1::ones(n);
    let offset = Array1::zeros(n);
    // Penalize only the slope; the intercept forms the penalty null space.
    let mut slope_penalty = Array2::<f64>::zeros((2, 2));
    slope_penalty[[1, 1]] = 1.0;
    let penalties = vec![BlockwisePenalty::new(0..2, slope_penalty)];
    let options = FitOptions {
        latent_cloglog: None,
        mixture_link: None,
        optimize_mixture: false,
        sas_link: None,
        optimize_sas: false,
        compute_inference: true,
        max_iter: 60,
        tol: 1e-6,
        nullspace_dims: vec![1],
        linear_constraints: None,
        firth_bias_reduction: false,
        adaptive_regularization: None,
        penalty_shrinkage_floor: None,
        rho_prior: Default::default(),
        kronecker_penalty_system: None,
        kronecker_factored: None,
    };
    let fit = fit_gam(
        design.view(),
        response.view(),
        prior_weights.view(),
        offset.view(),
        &penalties,
        LikelihoodFamily::BinomialProbit,
        &options,
    )
    .expect("probit fit should succeed");
    // One coefficient per design column, one lambda per penalty block.
    assert_eq!(fit.beta.len(), 2);
    assert_eq!(fit.lambdas.len(), 1);
    assert!(fit.edf_total().is_some_and(f64::is_finite));
    let pred = predict_gam(
        design.view(),
        fit.beta.view(),
        offset.view(),
        LikelihoodFamily::BinomialProbit,
    )
    .expect("probit predict should succeed");
    // Predicted means are probabilities and must stay in [0, 1].
    assert!(
        pred.mean
            .iter()
            .all(|p: &f64| p.is_finite() && (0.0..=1.0).contains(p))
    );
    // A constant p = 0.5 predictor scores exactly 0.25; the fit must beat it.
    let brier = (&pred.mean - &response)
        .mapv(|d| d * d)
        .mean()
        .unwrap_or(f64::INFINITY);
    assert!(
        brier < 0.25,
        "unexpectedly poor probit fit: brier={brier:.6e}"
    );
}
#[test]
fn probitworkingvectors_are_finite_for_extreme_eta() {
    // Numerical-robustness check for the probit IRLS update: extreme linear
    // predictors (|eta| up to 100) must not produce NaN/inf in mu, the
    // working weights, or the working response z.
    // NOTE(review): name reads better as `probit_working_vectors_...`;
    // left unchanged to keep existing `cargo test` name filters stable.
    let y = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 1.0]);
    // Linear predictors spanning both deep tails plus the exact midpoint.
    let eta = Array1::from_vec(vec![-100.0, -20.0, 0.0, 20.0, 100.0]);
    let w = Array1::ones(y.len());
    // Output buffers filled in-place by the IRLS update below.
    let mut mu = Array1::zeros(y.len());
    let mut weights = Array1::zeros(y.len());
    let mut z = Array1::zeros(y.len());
    update_glmvectors_by_family(
        y.view(),
        &eta,
        GlmLikelihoodSpec::canonical(GlmLikelihoodFamily::BinomialProbit),
        w.view(),
        &mut mu,
        &mut weights,
        &mut z,
    )
    .expect("probit working-vector update should succeed");
    // Range/finiteness invariants: mu is a probability, weights are
    // non-negative, and z must be finite everywhere.
    assert!(
        mu.iter().all(|v| v.is_finite() && *v >= 0.0 && *v <= 1.0),
        "probit mu out of [0,1] or non-finite: mu={mu:?}"
    );
    assert!(
        weights.iter().all(|v| v.is_finite() && *v >= 0.0),
        "probit weights non-finite or negative: weights={weights:?}"
    );
    assert!(
        z.iter().all(|v| v.is_finite()),
        "probit z non-finite: z={z:?}"
    );
    // Tail saturation: the normal CDF must collapse to ~0 / ~1 well before
    // |eta| = 100, and hit exactly 0.5 at eta = 0.
    assert!(
        mu[0] < 1e-6,
        "Φ(-100) must collapse to ~0; got mu[0]={}",
        mu[0]
    );
    assert!(mu[1] < 1e-6, "Φ(-20) must be tiny; got mu[1]={}", mu[1]);
    assert!(
        (mu[2] - 0.5).abs() < 1e-9,
        "Φ(0) must equal 0.5 within fp tol; got mu[2]={}",
        mu[2]
    );
    assert!(mu[3] > 1.0 - 1e-6, "Φ(+20) must be ~1; got mu[3]={}", mu[3]);
    assert!(
        mu[4] > 1.0 - 1e-6,
        "Φ(+100) must collapse to ~1; got mu[4]={}",
        mu[4]
    );
    // Monotonicity: eta is sorted ascending, so mu must be non-decreasing
    // (the 1e-15 slack absorbs floating-point rounding).
    for i in 1..mu.len() {
        assert!(
            mu[i] >= mu[i - 1] - 1e-15,
            "probit mu must be non-decreasing in eta; mu[{i}]={mu_i} < mu[{prev_i}]={mu_prev}",
            mu_i = mu[i],
            prev_i = i - 1,
            mu_prev = mu[i - 1]
        );
    }
    // Closed-form spot check: the probit IRLS weight at eta = 0 is
    // phi(0)^2 / (mu (1 - mu)) = (1/sqrt(2π))^2 / 0.25 = 2/π.
    let expected_w_at_zero = 2.0 / std::f64::consts::PI;
    assert!(
        (weights[2] - expected_w_at_zero).abs() < 1e-10,
        "probit IRLS weight at η=0 must equal 2/π = {expected_w_at_zero}; got {}",
        weights[2]
    );
    // Deep-tail weights vanish, so saturated rows cannot dominate the
    // weighted least-squares system.
    assert!(
        weights[0] < 1e-6,
        "probit weight at η=-100 must be ~0; got w[0]={}",
        weights[0]
    );
    assert!(
        weights[4] < 1e-6,
        "probit weight at η=+100 must be ~0; got w[4]={}",
        weights[4]
    );
    // Directionality: the working-response offset (z - eta) must share the
    // sign of the residual y - mu, i.e. z pulls eta toward the observation.
    for i in 0..y.len() {
        let residual = y[i] - mu[i];
        if residual.abs() > 1e-9 {
            let pull = z[i] - eta[i];
            assert!(
                pull * residual >= -1e-12,
                "probit working response must pull η toward y on row {i}: \
                 y={}, mu={}, eta={}, z={}, pull={}, residual={}",
                y[i],
                mu[i],
                eta[i],
                z[i],
                pull,
                residual,
            );
        }
    }
}
#[test]
fn cloglog_fit_and_predict_fast_integration() {
    // End-to-end smoke test: simulate Bernoulli data from a complementary
    // log-log model on a 1-D linear predictor, fit a lightly penalized GAM,
    // and verify fit shape, prediction range, and predictive quality.
    //
    // Fix: previously this test only checked that predictions lie in [0, 1];
    // a degenerate fit would still pass. It now mirrors the probit sibling's
    // post-fit assertions (coefficient/lambda shape, finite EDF, Brier score).
    let n = 400usize;
    let mut x = Array2::<f64>::zeros((n, 2));
    let mut y = Array1::<f64>::zeros(n);
    let mut rng = StdRng::seed_from_u64(17);
    for i in 0..n {
        // Evenly spaced covariate on [-2, 2]; true model: eta = -0.4 + 0.9 x.
        let xi = -2.0 + 4.0 * (i as f64) / (n as f64 - 1.0);
        let eta = -0.4 + 0.9 * xi;
        // Clamp the linear predictor before the double exponential so the
        // simulated probability cannot overflow/underflow.
        let z = eta.clamp(-30.0, 30.0);
        let p = 1.0 - (-(z.exp())).exp();
        x[[i, 0]] = 1.0;
        x[[i, 1]] = xi;
        y[i] = if rng.random::<f64>() < p { 1.0 } else { 0.0 };
    }
    let weights = Array1::ones(n);
    let offset = Array1::zeros(n);
    // Penalize only the slope; the intercept forms the penalty null space.
    let mut s = Array2::<f64>::zeros((2, 2));
    s[[1, 1]] = 1.0;
    let s_list = vec![BlockwisePenalty::new(0..2, s)];
    let fit = fit_gam(
        x.view(),
        y.view(),
        weights.view(),
        offset.view(),
        &s_list,
        LikelihoodFamily::BinomialCLogLog,
        &FitOptions {
            latent_cloglog: None,
            mixture_link: None,
            optimize_mixture: false,
            sas_link: None,
            optimize_sas: false,
            compute_inference: true,
            max_iter: 60,
            tol: 1e-6,
            nullspace_dims: vec![1],
            linear_constraints: None,
            firth_bias_reduction: false,
            adaptive_regularization: None,
            penalty_shrinkage_floor: None,
            rho_prior: Default::default(),
            kronecker_penalty_system: None,
            kronecker_factored: None,
        },
    )
    .expect("cloglog fit should succeed");
    // One coefficient per design column, one lambda per penalty block.
    assert_eq!(fit.beta.len(), 2);
    assert_eq!(fit.lambdas.len(), 1);
    assert!(fit.edf_total().is_some_and(f64::is_finite));
    let pred = predict_gam(
        x.view(),
        fit.beta.view(),
        offset.view(),
        LikelihoodFamily::BinomialCLogLog,
    )
    .expect("cloglog predict should succeed");
    // Predicted means are probabilities and must stay in [0, 1].
    assert!(
        pred.mean
            .iter()
            .all(|v: &f64| v.is_finite() && *v >= 0.0 && *v <= 1.0)
    );
    // A constant p = 0.5 predictor scores exactly 0.25; the fit must beat it.
    let brier = (&pred.mean - &y)
        .mapv(|v| v * v)
        .mean()
        .unwrap_or(f64::INFINITY);
    assert!(
        brier < 0.25,
        "unexpectedly poor cloglog fit: brier={brier:.6e}"
    );
}
#[test]
fn cloglogworkingvectors_are_finite_for_extreme_eta() {
    // Numerical-robustness check for the complementary log-log IRLS update:
    // extreme linear predictors (|eta| up to 100) must not produce NaN/inf
    // in mu, the working weights, or the working response z.
    // NOTE(review): name reads better as `cloglog_working_vectors_...`;
    // left unchanged to keep existing `cargo test` name filters stable.
    let y = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 1.0]);
    // Linear predictors spanning both deep tails plus the midpoint.
    let eta = Array1::from_vec(vec![-100.0, -20.0, 0.0, 20.0, 100.0]);
    let w = Array1::ones(y.len());
    // Output buffers filled in-place by the IRLS update below.
    let mut mu = Array1::zeros(y.len());
    let mut weights = Array1::zeros(y.len());
    let mut z = Array1::zeros(y.len());
    update_glmvectors_by_family(
        y.view(),
        &eta,
        GlmLikelihoodSpec::canonical(GlmLikelihoodFamily::BinomialCLogLog),
        w.view(),
        &mut mu,
        &mut weights,
        &mut z,
    )
    .expect("cloglog working-vector update should succeed");
    // Range/finiteness invariants: mu is a probability, weights are
    // non-negative, and z must be finite everywhere.
    assert!(
        mu.iter().all(|v| v.is_finite() && *v >= 0.0 && *v <= 1.0),
        "cloglog mu out of [0,1] or non-finite: mu={mu:?}"
    );
    assert!(
        weights.iter().all(|v| v.is_finite() && *v >= 0.0),
        "cloglog weights non-finite or negative: weights={weights:?}"
    );
    assert!(
        z.iter().all(|v| v.is_finite()),
        "cloglog z non-finite: z={z:?}"
    );
    // Lower tail: mu(eta) = 1 - exp(-exp(eta)) ≈ exp(eta), vanishing fast.
    assert!(
        mu[0] < 1e-6,
        "cloglog μ(-100) must collapse to ~0; got mu[0]={}",
        mu[0]
    );
    // Midpoint: mu(0) = 1 - exp(-exp(0)) = 1 - e^{-1} exactly.
    let expected_zero = 1.0 - (-1.0_f64).exp();
    assert!(
        (mu[2] - expected_zero).abs() < 1e-9,
        "cloglog μ(0) must equal 1 - exp(-1) = {expected_zero}; got mu[2]={}",
        mu[2]
    );
    // Upper tail saturates rapidly because of the double exponential.
    assert!(
        mu[3] > 1.0 - 1e-3,
        "cloglog μ(+20) must be ~1 (exp(20) saturates the inner exp); got mu[3]={}",
        mu[3]
    );
    assert!(
        mu[4] > 1.0 - 1e-6,
        "cloglog μ(+100) must collapse to ~1; got mu[4]={}",
        mu[4]
    );
    // Monotonicity: eta is sorted ascending, so mu must be non-decreasing
    // (the 1e-15 slack absorbs floating-point rounding).
    for i in 1..mu.len() {
        assert!(
            mu[i] >= mu[i - 1] - 1e-15,
            "cloglog mu must be non-decreasing in eta; mu[{i}]={mu_i} < mu[{prev_i}]={mu_prev}",
            mu_i = mu[i],
            prev_i = i - 1,
            mu_prev = mu[i - 1]
        );
    }
    // NOTE(review): unlike the probit variant of this test, this one does
    // not check tail weights or the z-pull direction; consider adding
    // those assertions for parity.
}