Skip to main content

gam_solve/
psis.rs

1//! Pareto-smoothed importance-sampling utilities.
2//!
//! The implementation is intentionally self-contained: it estimates the
//! generalized-Pareto tail shape `k` from the positive excesses of the largest
//! weights over a threshold and replaces only that empirical tail by monotone
//! GPD expected quantiles.  The
6//! returned `k_hat` has the usual GPD tail interpretation: values near zero
7//! indicate light tails, `k > 0.5` indicates that the fitted tail has infinite
8//! variance, and larger values mark increasingly unstable upper tails. Consumers
9//! decide whether that tail is a draw-wise PSIS reliability diagnostic or another
10//! influence diagnostic based on what the supplied weights represent.
11//!
12//! The shape is recovered with the Zhang–Stephens (2009) empirical-Bayes
13//! profile estimator — the same GPD tail fit used by `loo`/ArviZ for draw-wise
14//! PSIS diagnostics.
15//! Crucially it is consistent across the entire `k ∈ (−∞, ∞)` range, including
16//! the dangerous `k ≥ 0.5` regime where the GPD variance is infinite and the
17//! older method-of-moments form `k = ½(1 − μ²/Var)` is structurally capped
18//! below `0.5` and so cannot fire a heavy-tail gate.
19
/// Outcome of [`pareto_smooth_weights`]: the smoothed weights plus the tail
/// bookkeeping needed to interpret them.
#[derive(Clone, Debug)]
pub struct PsisResult {
    /// The input weights in their original order, with only the `tail_count`
    /// largest entries replaced by non-decreasing GPD quantiles.
    pub smoothed: Vec<f64>,
    /// Fitted generalized-Pareto shape of the tail, as returned by
    /// `fit_gpd_moments` (i.e. after its weak shrinkage toward 0.5).
    pub k_hat: f64,
    /// Position (`n - tail_count`) in the *ascending-sorted* weights at which
    /// the tail begins. It indexes the sorted order, not `smoothed`, which
    /// keeps the caller's original order.
    pub tail_start: usize,
    /// How many of the largest weights were replaced.
    pub tail_count: usize,
}
27
/// Smallest tail (and smallest number of positive excesses) for which a GPD
/// fit is attempted; `pareto_smooth_weights` additionally demands at least
/// twice this many weights overall.
pub const MIN_TAIL_COUNT: usize = 5;
/// Cap on the tail size as a fraction of the sample, applied before the
/// `MIN_TAIL_COUNT` floor.
const MAX_TAIL_FRACTION: f64 = 0.2;
30
31/// Pareto-smooth a non-negative weight vector and report the fitted GPD tail
32/// shape.  Non-tail observations are left bit-identical; only the largest tail
33/// observations are replaced by sorted GPD expected quantiles and then clipped
34/// to be non-decreasing in the original sorted order.
35pub fn pareto_smooth_weights(weights: &[f64]) -> Option<PsisResult> {
36    if weights.len() < MIN_TAIL_COUNT * 2 || weights.iter().any(|w| !w.is_finite() || *w < 0.0) {
37        return None;
38    }
39    let mut indexed: Vec<(usize, f64)> = weights.iter().copied().enumerate().collect();
40    indexed.sort_by(|a, b| a.1.total_cmp(&b.1));
41    let n = indexed.len();
42    let tail_count = ((n as f64).sqrt().ceil() as usize)
43        .max(MIN_TAIL_COUNT)
44        .min(((MAX_TAIL_FRACTION * n as f64).ceil() as usize).max(MIN_TAIL_COUNT))
45        .min(n - 1);
46    let tail_start = n - tail_count;
47    let threshold = indexed[tail_start - 1].1;
48    let excesses: Vec<f64> = indexed[tail_start..]
49        .iter()
50        .map(|(_, w)| (w - threshold).max(0.0))
51        .collect();
52    let (k_hat, sigma_hat) = fit_gpd_moments(&excesses)?;
53    let mut smoothed = weights.to_vec();
54    let mut previous = threshold;
55    for (rank, (original_idx, _)) in indexed[tail_start..].iter().enumerate() {
56        let p = (rank as f64 + 0.5) / tail_count as f64;
57        let q = threshold + gpd_quantile(p, k_hat, sigma_hat).max(0.0);
58        let monotone = q.max(previous);
59        smoothed[*original_idx] = monotone;
60        previous = monotone;
61    }
62    Some(PsisResult {
63        smoothed,
64        k_hat,
65        tail_start,
66        tail_count,
67    })
68}
69
/// Fit a generalized-Pareto distribution `1 − (1 + k·x/σ)^(−1/k)` to positive
/// excesses with the Zhang–Stephens (2009) empirical-Bayes profile estimator.
///
/// Despite its name, this function does not use the method of moments (the
/// older moment estimator it replaced is described in the module docs).
///
/// The GPD is reparameterised by the single scalar `b = −k/σ`.  For a fixed
/// `b` the shape has the closed-form profile MLE
/// `k(b) = mean_i log(1 − b·xᵢ)`, and the profile log-likelihood is
/// `ℓ(b) = n·(log(−b/k(b)) − k(b) − 1)`.
///
/// A grid of `m = 30 + ⌊√n⌋` candidates is built:
/// `b_j = 1/x_max + (1 − √(m/(j − ½))) / (PRIOR_BS·x*)` for `j = 1..=m`.
/// Here `x*` is the lower-quartile observation. Every candidate lies on the
/// admissible side of the boundary `b < 1/x_max`; candidates that round onto
/// or past it are rejected by `profile_shape`.
///
/// `b` is then averaged under the softmax of `ℓ`. This empirical-Bayes
/// posterior mean, rather than a single maximiser, is markedly more stable in
/// small samples. The shape is read back off the posterior-mean `b` and shrunk
/// toward `0.5` by a weakly-informative prior worth `PRIOR_K`
/// pseudo-observations (negligible once `n ≫ PRIOR_K`). The scale is computed
/// from the *unshrunk* shape.
///
/// Compare the method-of-moments identity `k = ½(1 − μ²/Var)`. It is bounded
/// above by `0.5` for every real sample, so it can never report the heavy tails
/// (`k > 0.7`) the diagnostic exists to flag. This estimator instead tracks `k`
/// up to and beyond `1`, recovering the true shape to within a few hundredths
/// on large, exactly-quantiled samples.
///
/// Returns `(k_hat, sigma_hat)` with `sigma_hat > 0`. Returns `None` when:
/// - fewer than `MIN_TAIL_COUNT` finite, strictly positive excesses remain
///   after zero, negative and non-finite entries are dropped;
/// - no grid candidate has a finite profile likelihood;
/// - the posterior-mean `b` does not give a finite, admissible fit.
///
/// Degenerate (all-equal) excesses are *not* guaranteed to yield `None`; they
/// may produce a finite light-tail fit.
pub fn fit_gpd_moments(excesses: &[f64]) -> Option<(f64, f64)> {
    let mut x: Vec<f64> = excesses
        .iter()
        .copied()
        .filter(|v| v.is_finite() && *v > 0.0)
        .collect();
    if x.len() < MIN_TAIL_COUNT {
        return None;
    }
    x.sort_by(f64::total_cmp);
    let n = x.len();
    let nf = n as f64;

    // Admissible region is `b < 1/x_max`. The lower-quartile order statistic
    // `x_star` (1-based index `floor(n/4 + 0.5)`, as in Zhang & Stephens / loo)
    // sets the spread of the candidate grid.
    let x_max = x[n - 1];
    let q_idx = ((nf / 4.0 + 0.5).floor() as usize)
        .saturating_sub(1)
        .min(n - 1);
    let x_star = x[q_idx];
    if !(x_max.is_finite() && x_max > 0.0 && x_star.is_finite() && x_star > 0.0) {
        return None;
    }

    // Weakly-informative prior: `PRIOR_BS` controls grid spread, `PRIOR_K`
    // pseudo-observations shrink the final shape toward 0.5 (Zhang–Stephens / loo).
    const PRIOR_BS: f64 = 3.0;
    const PRIOR_K: f64 = 10.0;
    // Grid size `m = 30 + floor(sqrt(n))`.
    let m_est = 30 + (nf.sqrt() as usize);

    // Profile log-likelihood `ℓ(b)` on the candidate grid (the `n·` factor is
    // retained so the softmax weights match loo/ArviZ). Inadmissible
    // candidates get `ℓ = −∞` and therefore zero weight below.
    let mut b_grid = Vec::with_capacity(m_est);
    let mut len_scale = Vec::with_capacity(m_est);
    let mut max_ls = f64::NEG_INFINITY;
    for j in 1..=m_est {
        // `1 − sqrt(m/(j − ½)) < 0` for every `j ≤ m`, so each candidate sits
        // below `1/x_max`.
        let b =
            (1.0 - (m_est as f64 / (j as f64 - 0.5)).sqrt()) / (PRIOR_BS * x_star) + 1.0 / x_max;
        let k = profile_shape(b, &x);
        let arg = k.map(|k| -(b / k));
        let ls = if let (Some(k), Some(arg)) = (k, arg) {
            if arg.is_finite() && arg > 0.0 {
                nf * (arg.ln() - k - 1.0)
            } else {
                f64::NEG_INFINITY
            }
        } else {
            f64::NEG_INFINITY
        };
        if ls > max_ls {
            max_ls = ls;
        }
        b_grid.push(b);
        len_scale.push(ls);
    }
    if !max_ls.is_finite() {
        return None;
    }

    // Posterior mean of `b` under the softmax of `ℓ`. Subtracting `max_ls`
    // before exponentiating keeps every weight in `[0, 1]`, avoiding overflow.
    let mut weight_sum = 0.0;
    let mut b_post = 0.0;
    for (&b, &ls) in b_grid.iter().zip(len_scale.iter()) {
        let w = if ls.is_finite() {
            (ls - max_ls).exp()
        } else {
            0.0
        };
        weight_sum += w;
        b_post += w * b;
    }
    if !(weight_sum.is_finite() && weight_sum > 0.0) {
        return None;
    }
    b_post /= weight_sum;
    // `b = 0` would make `σ = −k/b` undefined.
    if !b_post.is_finite() || b_post == 0.0 {
        return None;
    }

    let k_raw = profile_shape(b_post, &x)?;
    // The GPD scale comes from `b_post` and the *unshrunk* profile shape:
    // `σ = −k(b_post)/b_post`, mirroring loo/ArviZ.
    //
    // For admissible `b_post` and positive `xᵢ`, `k_raw = mean log(1 − b·xᵢ)`
    // has the opposite sign of `b_post`:
    // - `b_post > 0` gives every `1 − b·xᵢ ∈ (0, 1)`, so `k_raw < 0`;
    // - `b_post < 0` gives every `1 − b·xᵢ > 1`, so `k_raw > 0`.
    //
    // Either way `σ > 0`, and a light tail (`k_raw < 0`) is accepted rather
    // than rejected.
    let sigma = -k_raw / b_post;
    // Shrink only the *reported* shape toward 0.5 by `PRIOR_K` pseudo-observations.
    // Because σ was fixed above, this prior cannot flip the sign relationship
    // between `k` and `b_post` and spuriously reject a valid light-tail fit.
    let k = (nf * k_raw + PRIOR_K * 0.5) / (nf + PRIOR_K);
    if !(k.is_finite() && sigma.is_finite() && sigma > 0.0) {
        return None;
    }
    Some((k, sigma))
}
192
/// Closed-form profile shape `k(b) = mean_i log(1 − b·xᵢ)` for fixed `b`.
///
/// Each term is evaluated as `ln_1p(−b·xᵢ)`, the `log1p` form used by
/// loo/ArviZ. When `b·xᵢ` is tiny (near-exponential tails, `b → 0`), the naive
/// `(1 − b·xᵢ).ln()` rounds `1 − b·xᵢ` to exactly `1` and drops the term.
/// `ln_1p` keeps full relative precision there.
///
/// A candidate is admissible only when every `1 − b·xᵢ` is finite and strictly
/// positive, i.e. every `b·xᵢ < 1`. Boundary or out-of-domain candidates are
/// rejected before any log is evaluated, so they cannot leak NaN/Inf through
/// the profile likelihood.
///
/// Returns `None` for a non-finite `b`, an empty `x`, or an inadmissible `b`.
#[inline]
fn profile_shape(b: f64, x: &[f64]) -> Option<f64> {
    if !b.is_finite() || x.is_empty() {
        return None;
    }
    let mut acc = 0.0_f64;
    for &xi in x {
        let t = -b * xi;
        // `1 + t > 0` ⇔ `b·xᵢ < 1`: exactly the admissibility test
        // `1 − b·xᵢ > 0`, but phrased so the log can use `ln_1p`.
        if !(t.is_finite() && t > -1.0) {
            return None;
        }
        acc += t.ln_1p();
    }
    Some(acc / x.len() as f64)
}
213
/// Quantile function of the GPD `1 − (1 + k·x/σ)^(−1/k)` at probability `p`.
///
/// The survival probability `1 − p` is clamped to `[1e-12, 1]`, so `p ≥ 1`
/// still yields a finite value.
///
/// The general branch is evaluated as `σ·expm1(−k·ln(1 − p))/k`. This is
/// algebraically identical to `σ·((1 − p)^(−k) − 1)/k`, but it avoids the
/// catastrophic cancellation in `(1 − p)^(−k) − 1` when `k` is small. The
/// exponential limit `−σ·ln(1 − p)` is used only for `|k| < 1e-12`, where its
/// truncation error is about 1e-11 relative.
#[inline]
fn gpd_quantile(p: f64, k: f64, sigma: f64) -> f64 {
    const EXPONENTIAL_LIMIT: f64 = 1e-12;
    let log_survival = (1.0 - p).clamp(1e-12, 1.0).ln();
    if k.abs() < EXPONENTIAL_LIMIT {
        -sigma * log_survival
    } else {
        sigma * (-k * log_survival).exp_m1() / k
    }
}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Inverse-CDF draw from a GPD(k, σ) at uniform `u` (exponential limit for k ≈ 0).
    fn gpd_sample(u: f64, k: f64, sigma: f64) -> f64 {
        if k.abs() < 1e-12 {
            -sigma * (1.0 - u).ln()
        } else {
            sigma * ((1.0 - u).powf(-k) - 1.0) / k
        }
    }

    #[test]
    fn psis_k_hat_recovers_known_generalized_pareto_tail() {
        let k_true = 0.35_f64;
        let sigma = 1.7_f64;
        let mut xs = Vec::new();
        for i in 1..=10_000 {
            let u = (i as f64 - 0.5) / 10_000.0;
            xs.push(gpd_sample(u, k_true, sigma));
        }
        let (k_hat, sigma_hat) = fit_gpd_moments(&xs).expect("GPD fit should succeed");
        assert!(
            (k_hat - k_true).abs() < 0.03,
            "k_hat={k_hat}, k_true={k_true}"
        );
        assert!(
            (sigma_hat - sigma).abs() < 0.08,
            "sigma_hat={sigma_hat}, sigma={sigma}"
        );
    }

    #[test]
    fn pareto_smoothing_preserves_nontail_and_reports_heavy_tail() {
        // A flat block of baseline weights plus a genuine GPD(k=0.7) tail. The
        // tail is heavy enough to sit in the infinite-variance regime, so a
        // consistent estimator must report `k_hat > 0.5` for it.
        let mut w = vec![1.0; 200];
        for i in 1..=120 {
            let u = (i as f64 - 0.5) / 120.0;
            w.push(1.0 + gpd_sample(u, 0.7, 0.5));
        }
        let out = pareto_smooth_weights(&w).expect("PSIS should fit a positive tail");
        assert_eq!(out.smoothed.len(), w.len());
        assert_eq!(out.tail_start + out.tail_count, w.len());
        // `w` is already non-decreasing, so the (stable) sorted order equals
        // the input order and the first `tail_start` entries are the non-tail.
        assert!(
            out.smoothed[..out.tail_start]
                .iter()
                .zip(&w[..out.tail_start])
                .all(|(s, o)| s.to_bits() == o.to_bits()),
            "every non-tail weight must be left bit-identical"
        );
        assert!(
            out.smoothed[out.tail_start..]
                .windows(2)
                .all(|pair| pair[0] <= pair[1]),
            "smoothed tail must be non-decreasing"
        );
        assert!(
            out.k_hat > 0.5,
            "genuine GPD(k=0.7) tail should be flagged as heavy (infinite variance); got k_hat={}",
            out.k_hat
        );
    }

    /// Invalid or too-short inputs, and inputs whose tail has no positive
    /// excess over the threshold, must be rejected rather than smoothed.
    #[test]
    fn pareto_smoothing_rejects_short_or_invalid_weights() {
        assert!(
            pareto_smooth_weights(&[1.0; 9]).is_none(),
            "fewer than 2 * MIN_TAIL_COUNT weights"
        );
        assert!(
            pareto_smooth_weights(&[1.0; 20]).is_none(),
            "constant weights have no positive tail excesses"
        );
        let mut w: Vec<f64> = (1..=20).map(f64::from).collect();
        w[3] = -0.5;
        assert!(pareto_smooth_weights(&w).is_none(), "negative weight");
        w[3] = f64::NAN;
        assert!(pareto_smooth_weights(&w).is_none(), "NaN weight");
        w[3] = f64::INFINITY;
        assert!(pareto_smooth_weights(&w).is_none(), "infinite weight");
    }

    /// Regression for #585 from the *shape-recovery* angle: the moment estimator
    /// `k = ½(1 − μ²/Var)` is structurally `≤ 0.5`, so heavy tails collapsed
    /// onto an indistinguishable ~0.5. The Zhang–Stephens estimator must instead
    /// recover the true shape across the dangerous range — and the fitted shapes
    /// must stay *strictly ordered*, since a diagnostic that cannot separate
    /// `k=0.7` from `k=1.0` is useless.
    #[test]
    fn psis_k_hat_tracks_and_orders_heavy_tails() {
        let recover = |k_true: f64| -> f64 {
            let xs: Vec<f64> = (1..=100_000)
                .map(|i| gpd_sample((i as f64 - 0.5) / 100_000.0, k_true, 1.0))
                .collect();
            fit_gpd_moments(&xs).expect("GPD fit should succeed").0
        };
        let mut last = f64::NEG_INFINITY;
        for &k_true in &[0.5_f64, 0.7, 0.85, 1.0, 1.2] {
            let k_hat = recover(k_true);
            assert!(
                (k_hat - k_true).abs() < 0.05,
                "k_true={k_true}: fitted k_hat={k_hat} not within 0.05"
            );
            assert!(
                k_hat > last,
                "k_hat must increase with the true shape: {k_hat} !> {last}"
            );
            last = k_hat;
        }
    }

    /// The estimator must degrade gracefully: too few positive excesses and
    /// fully degenerate (all-equal) inputs return `None` rather than NaN/`0.5`.
    #[test]
    fn psis_gpd_fit_handles_degenerate_inputs() {
        assert!(
            fit_gpd_moments(&[1.0, 2.0, 3.0]).is_none(),
            "fewer than MIN_TAIL_COUNT"
        );
        assert!(
            fit_gpd_moments(&[0.0, -1.0, f64::NAN, 0.0]).is_none(),
            "no positive finite excesses"
        );
        // All-equal positive excesses are degenerate (x_max == x_star). The fit
        // must not panic or emit NaN, and — most importantly — must never report
        // a spuriously *heavy* tail; a near-constant block is the lightest tail
        // there is, so any returned shape must sit well below the 0.5 gate.
        if let Some((k_hat, sigma_hat)) = fit_gpd_moments(&[2.0; 50]) {
            assert!(k_hat.is_finite() && sigma_hat.is_finite() && sigma_hat > 0.0);
            assert!(
                k_hat < 0.5,
                "degenerate equal excesses must not be flagged heavy; got k_hat={k_hat}"
            );
        }
    }

    #[test]
    fn psis_profile_shape_rejects_inadmissible_candidates() {
        let x = [0.25, 1.0, 2.0];
        assert!(
            profile_shape(0.49, &x).is_some(),
            "b below 1/x_max is admissible for all excesses"
        );
        assert!(
            profile_shape(0.5, &x).is_none(),
            "b at 1/x_max puts the largest excess on the log boundary"
        );
        assert!(
            profile_shape(0.75, &x).is_none(),
            "b above 1/x_max makes at least one log argument negative"
        );
    }
}