use gam::basis::create_thin_plate_spline_basis_with_knot_count;
use gam::construction::canonicalize_penalty_spec;
use gam::estimate::PenaltySpec;
use gam::estimate::{FitOptions, fit_gam};
use gam::predict::predict_gam;
use gam::smooth::BlockwisePenalty;
use gam::types::LikelihoodFamily;
use ndarray::{Array1, Array2};
use rand::rngs::StdRng;
use rand::{RngExt, SeedableRng};
use rand_distr::{Distribution, Normal};
/// Noise-free smoke test for a 2D thin-plate spline fitted with `fit_gam`.
///
/// The test samples the surface `sin(pi * x1) + 0.5 * (x2 - 0.5)^2` on a regular
/// 12 x 10 grid over the unit square and builds a 24-knot TPS basis on those
/// points. It fits with two penalties (bending + ridge), then checks:
/// - the shape of the returned fit;
/// - that in-sample predictions reproduce the surface with MSE below `5e-2`.
#[test]
fn thin_plate_fit_gam_gaussian_fast_integration() {
    let nx = 12usize;
    let ny = 10usize;
    let n = nx * ny;
    let mut data = Array2::<f64>::zeros((n, 2));
    let mut y = Array1::<f64>::zeros(n);
    // Fill rows in (ix, iy) order with both coordinates scaled onto [0, 1].
    for ix in 0..nx {
        for iy in 0..ny {
            let row = ix * ny + iy;
            let x1 = ix as f64 / (nx as f64 - 1.0);
            let x2 = iy as f64 / (ny as f64 - 1.0);
            data[[row, 0]] = x1;
            data[[row, 1]] = x2;
            y[row] = (std::f64::consts::PI * x1).sin() + 0.5 * (x2 - 0.5).powi(2);
        }
    }
    let (tps, _knots) =
        create_thin_plate_spline_basis_with_knot_count(data.view(), 24).expect("TPS basis");
    let weights = Array1::ones(n);
    let offset = Array1::zeros(n);
    // Both penalties act on the full coefficient range of the TPS basis.
    let s_list = vec![
        BlockwisePenalty::new(0..tps.basis.ncols(), tps.penalty_bending.clone()),
        BlockwisePenalty::new(0..tps.basis.ncols(), tps.penalty_ridge.clone()),
    ];
    let fit = fit_gam(
        tps.basis.clone(),
        y.view(),
        weights.view(),
        offset.view(),
        &s_list,
        LikelihoodFamily::GaussianIdentity,
        &FitOptions {
            latent_cloglog: None,
            mixture_link: None,
            optimize_mixture: false,
            sas_link: None,
            optimize_sas: false,
            compute_inference: true,
            max_iter: 40,
            tol: 1e-6,
            // One entry per penalty in `s_list`. The 2D TPS bending penalty leaves
            // the linear polynomial part {1, x1, x2} unpenalised, i.e. its null
            // space has dimension d + 1 = 3; the ridge penalty has no null space.
            nullspace_dims: vec![3, 0],
            linear_constraints: None,
            firth_bias_reduction: false,
            adaptive_regularization: None,
            penalty_shrinkage_floor: None,
            rho_prior: Default::default(),
            kronecker_penalty_system: None,
            kronecker_factored: None,
        },
    )
    .expect("fit_gam with TPS should succeed");
    assert_eq!(fit.lambdas.len(), 2);
    assert_eq!(fit.beta.len(), tps.basis.ncols());
    assert!(fit.edf_total().is_some_and(f64::is_finite));
    let pred = predict_gam(
        tps.basis.clone(),
        fit.beta.view(),
        offset.view(),
        LikelihoodFamily::GaussianIdentity,
    )
    .expect("predict_gam should succeed");
    // A NaN prediction yields a NaN MSE, which fails the `<` comparison below.
    let mse = (&pred.mean - &y)
        .mapv(|v| v * v)
        .mean()
        .unwrap_or(f64::INFINITY);
    assert!(
        mse < 5e-2,
        "TPS integration fit is too inaccurate, mse={mse:.6e}"
    );
}
/// Train/test check for a 2D thin-plate spline on noisy simulated data.
///
/// The target surface (see `truth`) is a Gaussian bump centred at
/// (0.25, -0.15), plus a sine term in `x1`, a linear term in `x2` and an
/// `x1 * x2` interaction. The model is fitted on 900 uniformly drawn points
/// with N(0, 0.10^2) noise, then evaluated on 300 fresh noise-free points.
///
/// The test asserts that:
/// 1. test MSE is below `mse_test_bound`;
/// 2. test MSE is less than half the mean-only baseline;
/// 3. the fitted surface is higher at the bump centre than at two opposite
///    corners of the domain.
#[test]
fn thin_plate_fit_gam_gaussian_simulated_train_test() {
    /// Noise-free target surface shared by the training and test points, so the
    /// two sets cannot drift apart.
    fn truth(x1: f64, x2: f64) -> f64 {
        let r2 = (x1 - 0.25).powi(2) + (x2 + 0.15).powi(2);
        1.1 * (-r2 / (2.0 * 0.38 * 0.38)).exp() + 0.45 * (std::f64::consts::PI * x1).sin()
            - 0.30 * x2
            + 0.25 * x1 * x2
    }

    let n_train = 900usize;
    let n_test = 300usize;
    let mse_test_bound = 0.12_f64;
    let mut rng = StdRng::seed_from_u64(20260226);
    let noise = Normal::new(0.0, 0.10).expect("normal params must be valid");

    // Training data: each row consumes x1, x2, then one noise draw, in that order.
    let mut x_train = Array2::<f64>::zeros((n_train, 2));
    let mut y_train = Array1::<f64>::zeros(n_train);
    for i in 0..n_train {
        let x1 = rng.random_range(-1.0..1.0);
        let x2 = rng.random_range(-1.0..1.0);
        x_train[[i, 0]] = x1;
        x_train[[i, 1]] = x2;
        y_train[i] = truth(x1, x2) + noise.sample(&mut rng);
    }

    // `knots` is kept so test and probe bases are built on the same knot set.
    let (tps_train, knots) =
        create_thin_plate_spline_basis_with_knot_count(x_train.view(), 30).expect("TPS basis");
    let weights = Array1::ones(n_train);
    let offset = Array1::zeros(n_train);
    let s_list = vec![
        BlockwisePenalty::new(
            0..tps_train.basis.ncols(),
            tps_train.penalty_bending.clone(),
        ),
        BlockwisePenalty::new(0..tps_train.basis.ncols(), tps_train.penalty_ridge.clone()),
    ];
    let fit = fit_gam(
        tps_train.basis.clone(),
        y_train.view(),
        weights.view(),
        offset.view(),
        &s_list,
        LikelihoodFamily::GaussianIdentity,
        &FitOptions {
            latent_cloglog: None,
            mixture_link: None,
            optimize_mixture: false,
            sas_link: None,
            optimize_sas: false,
            compute_inference: true,
            max_iter: 60,
            tol: 1e-6,
            // One entry per penalty: the 2D bending penalty leaves {1, x1, x2}
            // unpenalised (null space dimension 3); the ridge penalty has none.
            nullspace_dims: vec![3, 0],
            linear_constraints: None,
            firth_bias_reduction: false,
            adaptive_regularization: None,
            penalty_shrinkage_floor: None,
            rho_prior: Default::default(),
            kronecker_penalty_system: None,
            kronecker_factored: None,
        },
    )
    .expect("fit_gam with TPS should succeed");
    assert_eq!(fit.lambdas.len(), 2);
    assert_eq!(fit.beta.len(), tps_train.basis.ncols());
    assert!(fit.edf_total().is_some_and(f64::is_finite));

    // Held-out, noise-free evaluation points drawn from the same RNG stream.
    let mut x_test = Array2::<f64>::zeros((n_test, 2));
    let mut y_test_true = Array1::<f64>::zeros(n_test);
    for i in 0..n_test {
        let x1 = rng.random_range(-1.0..1.0);
        let x2 = rng.random_range(-1.0..1.0);
        x_test[[i, 0]] = x1;
        x_test[[i, 1]] = x2;
        y_test_true[i] = truth(x1, x2);
    }
    let tps_test = gam::basis::create_thin_plate_spline_basis(x_test.view(), knots.view())
        .expect("TPS test basis");
    let pred = predict_gam(
        tps_test.basis.clone(),
        fit.beta.view(),
        Array1::zeros(n_test).view(),
        LikelihoodFamily::GaussianIdentity,
    )
    .expect("predict_gam should succeed");
    let mse_test = (&pred.mean - &y_test_true)
        .mapv(|v| v * v)
        .mean()
        .unwrap_or(f64::INFINITY);
    assert!(
        mse_test < mse_test_bound,
        "TPS simulated integration test is too inaccurate: \
mse_test={mse_test:.6e}, bound={mse_test_bound:.6e}"
    );

    // Baseline: predicting the mean of the true test values everywhere.
    let mean_truth = y_test_true.mean().unwrap_or(0.0);
    let mse_mean_baseline = (&y_test_true - mean_truth)
        .mapv(|v| v * v)
        .mean()
        .unwrap_or(f64::INFINITY);
    assert!(
        mse_test < 0.5 * mse_mean_baseline,
        "TPS fit must beat mean-only baseline by ≥50%: mse_test={mse_test:.6e}, \
mse_mean_baseline={mse_mean_baseline:.6e}"
    );

    // Probe the bump centre and two opposite corners of the [-1, 1]^2 domain.
    let probe = Array2::from_shape_vec((3, 2), vec![0.25_f64, -0.15, -0.95, 0.95, 0.95, -0.95])
        .expect("probe matrix shape");
    let probe_basis = gam::basis::create_thin_plate_spline_basis(probe.view(), knots.view())
        .expect("TPS probe basis");
    let probe_pred = predict_gam(
        probe_basis.basis.clone(),
        fit.beta.view(),
        Array1::zeros(3).view(),
        LikelihoodFamily::GaussianIdentity,
    )
    .expect("predict_gam should succeed for probe");
    let center_pred = probe_pred.mean[0];
    let corner_pred_a = probe_pred.mean[1];
    let corner_pred_b = probe_pred.mean[2];
    assert!(
        center_pred > corner_pred_a && center_pred > corner_pred_b,
        "TPS fit failed to learn the bump structure: center={center_pred:.4}, \
corner_a={corner_pred_a:.4}, corner_b={corner_pred_b:.4}"
    );
}
/// Train/test check for a 3D thin-plate spline on noisy simulated data.
///
/// Before fitting, the test canonicalizes the bending penalty and checks that
/// `root^T * root` reproduces it to within `1e-8` (max absolute entry).
///
/// It then fits on 600 noisy points (noise sd 0.08) and requires, on 250
/// fresh noise-free points:
/// - a finite total EDF above 1;
/// - test MSE below `0.09`.
#[test]
fn thin_plate_fit_gam_gaussian_3d_simulated_train_test() {
    /// Noise-free target surface shared by the training and test points:
    /// a Gaussian bump centred at (0.3, -0.2, 0.1), a sine term in `x1`,
    /// an `x2 * x3` interaction and a linear term in `x3`.
    fn truth(x1: f64, x2: f64, x3: f64) -> f64 {
        let r2 = (x1 - 0.3).powi(2) + (x2 + 0.2).powi(2) + (x3 - 0.1).powi(2);
        0.9 * (-r2 / (2.0 * 0.45 * 0.45)).exp() + 0.35 * (std::f64::consts::PI * x1).sin()
            - 0.25 * x2 * x3
            + 0.15 * x3
    }

    let n_train = 600usize;
    let n_test = 250usize;
    let mut rng = StdRng::seed_from_u64(20260301);
    let noise = Normal::new(0.0, 0.08).expect("normal params must be valid");

    // Training data: each row consumes x1, x2, x3, then one noise draw.
    let mut x_train = Array2::<f64>::zeros((n_train, 3));
    let mut y_train = Array1::<f64>::zeros(n_train);
    for i in 0..n_train {
        let x1 = rng.random_range(-1.0..1.0);
        let x2 = rng.random_range(-1.0..1.0);
        let x3 = rng.random_range(-1.0..1.0);
        x_train[[i, 0]] = x1;
        x_train[[i, 1]] = x2;
        x_train[[i, 2]] = x3;
        y_train[i] = truth(x1, x2, x3) + noise.sample(&mut rng);
    }

    let (tps_train, knots) =
        create_thin_plate_spline_basis_with_knot_count(x_train.view(), 36).expect("3D TPS basis");
    let s_list = vec![
        BlockwisePenalty::new(
            0..tps_train.basis.ncols(),
            tps_train.penalty_bending.clone(),
        ),
        BlockwisePenalty::new(0..tps_train.basis.ncols(), tps_train.penalty_ridge.clone()),
    ];

    // Canonicalize the bending penalty and verify its square-root factor.
    let p = s_list[0].local.nrows();
    let cp = canonicalize_penalty_spec(
        &PenaltySpec::Dense(s_list[0].local.clone()),
        p,
        0,
        "3D TPS test",
    )
    .expect("canonicalize penalty")
    .expect("penalty should have positive rank");
    assert!(
        cp.rank() > 0,
        "3D TPS bending penalty root should have positive rank"
    );
    let root_reconstructed = cp.root.t().dot(&cp.root);
    // `f64::max` returns the non-NaN operand, so a plain max-fold would report 0
    // for a NaN (or inf - inf) reconstruction and let it pass. Propagate NaN
    // instead so the tolerance assertion below fails on non-finite output.
    let reconstruction_max_abs = (&root_reconstructed - &cp.local)
        .iter()
        .fold(0.0_f64, |acc, &v| {
            if acc.is_nan() || v.is_nan() {
                f64::NAN
            } else {
                acc.max(v.abs())
            }
        });
    assert!(
        reconstruction_max_abs < 1e-8,
        "3D TPS bending root reconstruction mismatch: max_abs={reconstruction_max_abs:.3e}"
    );

    let weights = Array1::ones(n_train);
    let offset = Array1::zeros(n_train);
    let fit = fit_gam(
        tps_train.basis.clone(),
        y_train.view(),
        weights.view(),
        offset.view(),
        &s_list,
        LikelihoodFamily::GaussianIdentity,
        &FitOptions {
            latent_cloglog: None,
            mixture_link: None,
            optimize_mixture: false,
            sas_link: None,
            optimize_sas: false,
            compute_inference: true,
            max_iter: 80,
            tol: 1e-6,
            // One entry per penalty: the 3D bending penalty leaves
            // {1, x1, x2, x3} unpenalised (null space dimension 4); ridge has none.
            nullspace_dims: vec![4, 0],
            linear_constraints: None,
            firth_bias_reduction: false,
            adaptive_regularization: None,
            penalty_shrinkage_floor: None,
            rho_prior: Default::default(),
            kronecker_penalty_system: None,
            kronecker_factored: None,
        },
    )
    .expect("fit_gam with 3D TPS should succeed");
    assert_eq!(fit.lambdas.len(), 2);
    assert_eq!(fit.beta.len(), tps_train.basis.ncols());
    assert!(fit.edf_total().is_some_and(f64::is_finite));
    assert!(fit.edf_total().unwrap_or(0.0) > 1.0);

    // Held-out, noise-free evaluation points drawn from the same RNG stream.
    let mut x_test = Array2::<f64>::zeros((n_test, 3));
    let mut y_test_true = Array1::<f64>::zeros(n_test);
    for i in 0..n_test {
        let x1 = rng.random_range(-1.0..1.0);
        let x2 = rng.random_range(-1.0..1.0);
        let x3 = rng.random_range(-1.0..1.0);
        x_test[[i, 0]] = x1;
        x_test[[i, 1]] = x2;
        x_test[[i, 2]] = x3;
        y_test_true[i] = truth(x1, x2, x3);
    }
    let tps_test = gam::basis::create_thin_plate_spline_basis(x_test.view(), knots.view())
        .expect("3D test basis");
    let pred = predict_gam(
        tps_test.basis.clone(),
        fit.beta.view(),
        Array1::zeros(n_test).view(),
        LikelihoodFamily::GaussianIdentity,
    )
    .expect("3D predict_gam should succeed");
    let mse_test = (&pred.mean - &y_test_true)
        .mapv(|v| v * v)
        .mean()
        .unwrap_or(f64::INFINITY);
    assert!(
        mse_test < 0.09,
        "3D TPS simulated integration test is too inaccurate: mse_test={mse_test:.6e}"
    );
}