Skip to main content

gam_terms/smooth/
input_standardization.rs

//! Per-axis input standardization and length-scale compensation helpers for
//! the spatial smooth arm.
//!
//! Pure numeric helpers relocated from `smooth.rs` (issue #780
//! decomposition): per-column standard-deviation scales, in-place
//! standardization, the geometric-mean scale, and the kernel length-scale
//! compensation maps that keep the Matérn/Duchon/thin-plate range in original
//! coordinates after standardization. The relocation itself changed no
//! behavior: the bodies were moved verbatim and the parent re-imports each
//! name, so every call site is unchanged.
10
11use ndarray::{Array2, ArrayView2};
12
13/// Compute per-column standard deviations for spatial inputs.
14///
15/// Standardizing each covariate axis to unit spread makes the
16/// Matérn/Duchon/thin-plate kernel — and the `ψ = log κ = −log ℓ` REML
17/// length-scale optimizer that refines it — operate in scale-free coordinates,
18/// so the fit is invariant to an affine covariate rescale `x → a·x + b`. This
19/// matters in **one dimension too** (issue #1215): a 1-D `s(x, bs="tp")` whose
20/// kernel ran in raw covariate units seeded and bounded its `ψ`-optimizer off
21/// the raw magnitude, landing in a scale-dependent basin (a clean bimodal step
22/// across `|a| ⋛ 1`). Standardizing the single axis the same way as the d > 1
23/// axes removes that magnitude from the optimizer's view, so the selected `ψ̂`
24/// (hence the fitted curve) is scale-invariant. The frozen scale is replayed at
25/// predict, so original-unit queries map onto the same standardized geometry.
26///
27/// Returns `None` only when there is no axis or too few rows to estimate a
28/// spread, or when the caller already supplies frozen scales (prediction path).
29pub fn compute_spatial_input_scales(x: ArrayView2<'_, f64>) -> Option<Vec<f64>> {
30    let d = x.ncols();
31    if d == 0 {
32        return None;
33    }
34    let n = x.nrows() as f64;
35    if n < 2.0 {
36        return None;
37    }
38    let mut scales = Vec::with_capacity(d);
39    for j in 0..d {
40        let col = x.column(j);
41        let mean = col.sum() / n;
42        let var = col.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / (n - 1.0);
43        scales.push(var.sqrt().max(1e-12));
44    }
45    Some(scales)
46}
47
48/// Apply per-column standardization to a data matrix using precomputed scales.
49pub fn apply_input_standardization(x: &mut Array2<f64>, scales: &[f64]) {
50    for j in 0..x.ncols() {
51        let inv = 1.0 / scales[j];
52        x.column_mut(j).mapv_inplace(|v| v * inv);
53    }
54}
55
/// Geometric mean `(∏ s_a)^(1/d)` of the per-axis scales.
///
/// The mean is evaluated as `exp(mean(ln s_a))`, so a large `d` or very small
/// individual scales cannot overflow or underflow an explicit product.
///
/// The Matérn / Duchon / thin-plate auto-standardization paths use this value
/// to compensate the user's `length_scale`. That keeps the kernel range
/// expressed in *original* data coordinates after each axis is divided by σ_a:
///
///   ‖x_std − c_std‖ / L_eff with L_eff = L_user / σ_geom
///
/// For uniform σ_a (= σ_geom) this equals `‖x − c‖ / L_user` exactly. When the
/// σ_a differ, it reduces to the natural anisotropic-Mahalanobis
/// preconditioning. Choosing σ_geom = (∏σ_a)^(1/d) preserves the kernel volume
/// scale.
///
/// Returns `1.0` for an empty slice. The inputs are expected to be strictly
/// positive. Non-positive or NaN entries produce `0.0` or NaN instead of a
/// meaningful mean.
fn geometric_mean_scale(scales: &[f64]) -> f64 {
    match scales.len() {
        0 => 1.0,
        count => {
            let total_log: f64 = scales.iter().copied().map(f64::ln).sum();
            (total_log / count as f64).exp()
        }
    }
}
76
77pub fn compensate_length_scale_for_standardization(
78    length_scale: f64,
79    scales: &[f64],
80) -> f64 {
81    let sigma_geom = geometric_mean_scale(scales);
82    if sigma_geom > 0.0 && sigma_geom.is_finite() {
83        length_scale / sigma_geom
84    } else {
85        length_scale
86    }
87}
88
89pub fn compensate_optional_length_scale_for_standardization(
90    length_scale: Option<f64>,
91    scales: &[f64],
92) -> Option<f64> {
93    length_scale.map(|l| compensate_length_scale_for_standardization(l, scales))
94}