Skip to main content

gam_solve/
rho_uncertainty.rs

//! PSIS diagnostic for marginal smoothing-parameter uncertainty.
//!
//! The diagnostic treats the exact outer Hessian at `rho_hat` as a Laplace
//! proposal, evaluates the exact profiled criterion at a deterministic finite
//! set of proposal draws, and fits a GPD tail to the resulting importance
//! weights. A large `k_hat` is evidence that fixed-`rho` REML/LAML intervals are
//! inadequate for this fit; a small `k_hat` is only absence of heavy-tail
//! evidence at the probed points, not a proof about unprobed tails. A criterion
//! closure can agree with the Gaussian proposal at every deterministic draw and
//! still have catastrophic heavier tails elsewhere.
11
12use crate::psis::{MIN_TAIL_COUNT, pareto_smooth_weights};
13use ndarray::{Array1, Array2};
14
/// Number of proposal draws requested by `RhoUncertaintyCostGate::default()`.
const DEFAULT_SAMPLE_COUNT: usize = 32;
/// Largest smoothing-parameter dimension for which the diagnostic runs
/// automatically; larger problems are skipped by `cost_gate_allows`.
const MAX_AUTO_RHO_DIM: usize = 4;
/// Upper bound on the estimated work units (see `cost_gate_allows`) above
/// which the diagnostic is skipped.
const MAX_AUTO_WORK_UNITS: usize = 2_000_000;
18
19#[derive(Clone, Debug, PartialEq)]
20pub struct RhoUncertaintyDiagnostic {
21    pub k_hat: Option<f64>,
22    pub n_evaluations: usize,
23    pub status: RhoUncertaintyStatus,
24}
25
/// Classification attached to a [`RhoUncertaintyDiagnostic`].
#[derive(Clone, Debug, PartialEq)]
pub enum RhoUncertaintyStatus {
    /// The fitted `k_hat` was below 0.5, or the importance weights were
    /// numerically constant. This is absence of evidence at the probed
    /// points only, not a guarantee about the tails.
    NoEvidenceOfHeavyTails,
    /// The fitted `k_hat` was not below 0.5.
    HeavyTailsDetected {
        /// The tail-shape estimate that triggered this status.
        k_hat: f64,
    },
    /// The diagnostic did not produce a tail fit.
    Skipped {
        /// Human-readable explanation of why the run was skipped.
        reason: String,
    },
}
32
/// Optional problem dimensions used for cost estimation and for seeding the
/// deterministic proposal draws. `None` means the dimension is unknown.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct RhoUncertaintyProblemSize {
    /// Number of observations, if known.
    pub n_obs: Option<usize>,
    /// Number of model coefficients, if known.
    pub p_coefficients: Option<usize>,
}
38
39#[derive(Clone, Copy, Debug, PartialEq, Eq)]
40pub struct RhoUncertaintyCostGate {
41    pub sample_count: usize,
42    pub problem_size: RhoUncertaintyProblemSize,
43}
44
45impl Default for RhoUncertaintyCostGate {
46    fn default() -> Self {
47        Self {
48            sample_count: DEFAULT_SAMPLE_COUNT,
49            problem_size: RhoUncertaintyProblemSize::default(),
50        }
51    }
52}
53
54impl RhoUncertaintyDiagnostic {
55    pub fn skipped(reason: impl Into<String>, n_evaluations: usize) -> Self {
56        Self {
57            k_hat: None,
58            n_evaluations,
59            status: RhoUncertaintyStatus::Skipped {
60                reason: reason.into(),
61            },
62        }
63    }
64}
65
66pub fn cost_gate_allows(rho_dim: usize, gate: RhoUncertaintyCostGate) -> Result<usize, String> {
67    if rho_dim == 0 {
68        return Err("no smoothing parameters".to_string());
69    }
70    if rho_dim > MAX_AUTO_RHO_DIM {
71        return Err(format!(
72            "rho dimension {rho_dim} exceeds automatic PSIS diagnostic limit {MAX_AUTO_RHO_DIM}"
73        ));
74    }
75    let sample_count = gate.sample_count.max(2 * MIN_TAIL_COUNT);
76    let n = gate.problem_size.n_obs.unwrap_or(1);
77    let p = gate.problem_size.p_coefficients.unwrap_or(1);
78    let work_units = sample_count
79        .saturating_add(1)
80        .saturating_mul(rho_dim.max(1))
81        .saturating_mul(n.max(1))
82        .saturating_mul(p.max(1));
83    if work_units > MAX_AUTO_WORK_UNITS {
84        return Err(format!(
85            "estimated diagnostic cost {work_units} work units exceeds automatic limit {MAX_AUTO_WORK_UNITS} \
86             (M={sample_count}, K={rho_dim}, n={}, p={})",
87            gate.problem_size.n_obs.unwrap_or(0),
88            gate.problem_size.p_coefficients.unwrap_or(0),
89        ));
90    }
91    Ok(sample_count)
92}
93
94pub fn rho_uncertainty_diagnostic<F>(
95    rho_hat: &Array1<f64>,
96    outer_hessian_rho: &Array2<f64>,
97    gate: RhoUncertaintyCostGate,
98    mut criterion: F,
99) -> RhoUncertaintyDiagnostic
100where
101    F: FnMut(&Array1<f64>) -> Option<f64>,
102{
103    let rho_dim = rho_hat.len();
104    let sample_count = match cost_gate_allows(rho_dim, gate) {
105        Ok(sample_count) => sample_count,
106        Err(reason) => return RhoUncertaintyDiagnostic::skipped(reason, 0),
107    };
108    if outer_hessian_rho.nrows() != rho_dim || outer_hessian_rho.ncols() != rho_dim {
109        return RhoUncertaintyDiagnostic::skipped(
110            format!(
111                "outer rho Hessian shape {}x{} does not match K={rho_dim}",
112                outer_hessian_rho.nrows(),
113                outer_hessian_rho.ncols()
114            ),
115            0,
116        );
117    }
118    let Some(cost_hat) = criterion(rho_hat).filter(|value| value.is_finite()) else {
119        return RhoUncertaintyDiagnostic::skipped("criterion was not finite at rho_hat", 1);
120    };
121    let Some(proposal_factor) = proposal_factor_from_hessian(outer_hessian_rho) else {
122        return RhoUncertaintyDiagnostic::skipped("outer rho Hessian was not positive definite", 1);
123    };
124
125    let mut rng = DeterministicNormal::new(seed_from_problem(rho_hat, gate.problem_size));
126    let mut log_weights = Vec::with_capacity(sample_count);
127    let mut n_evaluations = 1usize;
128    for _draw in 0..sample_count {
129        let z = Array1::from_iter((0..rho_dim).map(|coord| rng.normal(coord)));
130        let rho = rho_hat + &proposal_factor.dot(&z);
131        let half_norm_sq = 0.5 * z.iter().map(|value| value * value).sum::<f64>();
132        let log_weight = match criterion(&rho) {
133            Some(cost) if cost.is_finite() => -cost + cost_hat + half_norm_sq,
134            _ => f64::NEG_INFINITY,
135        };
136        log_weights.push(log_weight);
137        n_evaluations = n_evaluations.saturating_add(1);
138    }
139
140    let max_log_weight = log_weights
141        .iter()
142        .copied()
143        .filter(|value| value.is_finite())
144        .fold(f64::NEG_INFINITY, f64::max);
145    if !max_log_weight.is_finite() {
146        return RhoUncertaintyDiagnostic::skipped(
147            "all proposal draws had non-finite criterion values",
148            n_evaluations,
149        );
150    }
151    let weights: Vec<f64> = log_weights
152        .iter()
153        .map(|&value| {
154            if value.is_finite() {
155                (value - max_log_weight).exp()
156            } else {
157                0.0
158            }
159        })
160        .collect();
161    let (min_weight, max_weight) = weights
162        .iter()
163        .copied()
164        .fold((f64::INFINITY, f64::NEG_INFINITY), |(min_w, max_w), w| {
165            (min_w.min(w), max_w.max(w))
166        });
167    if max_weight.is_finite()
168        && min_weight.is_finite()
169        && max_weight > 0.0
170        && (max_weight - min_weight) <= 1e-12 * max_weight.max(1.0)
171    {
172        return RhoUncertaintyDiagnostic {
173            k_hat: Some(0.0),
174            n_evaluations,
175            status: RhoUncertaintyStatus::NoEvidenceOfHeavyTails,
176        };
177    }
178    let Some(psis) = pareto_smooth_weights(&weights) else {
179        return RhoUncertaintyDiagnostic::skipped(
180            "PSIS tail fit failed for rho-importance weights",
181            n_evaluations,
182        );
183    };
184    let k_hat = psis.k_hat;
185    let status = if k_hat < 0.5 {
186        RhoUncertaintyStatus::NoEvidenceOfHeavyTails
187    } else {
188        RhoUncertaintyStatus::HeavyTailsDetected { k_hat }
189    };
190    RhoUncertaintyDiagnostic {
191        k_hat: Some(k_hat),
192        n_evaluations,
193        status,
194    }
195}
196
197fn proposal_factor_from_hessian(hessian: &Array2<f64>) -> Option<Array2<f64>> {
198    let chol = cholesky_lower(hessian)?;
199    let n = chol.nrows();
200    let mut inverse_lower = Array2::<f64>::zeros((n, n));
201    for col in 0..n {
202        for row in 0..n {
203            let mut acc = if row == col { 1.0 } else { 0.0 };
204            for k in 0..row {
205                acc -= chol[[row, k]] * inverse_lower[[k, col]];
206            }
207            let diagonal = chol[[row, row]];
208            if !(diagonal.is_finite() && diagonal > 0.0) {
209                return None;
210            }
211            inverse_lower[[row, col]] = acc / diagonal;
212        }
213    }
214    let mut factor = Array2::<f64>::zeros((n, n));
215    for row in 0..n {
216        for col in 0..n {
217            factor[[row, col]] = inverse_lower[[col, row]];
218        }
219    }
220    Some(factor)
221}
222
223fn cholesky_lower(matrix: &Array2<f64>) -> Option<Array2<f64>> {
224    let n = matrix.nrows();
225    if n == 0 || matrix.ncols() != n || matrix.iter().any(|value| !value.is_finite()) {
226        return None;
227    }
228    let mut lower = Array2::<f64>::zeros((n, n));
229    for row in 0..n {
230        for col in 0..=row {
231            let mut acc = matrix[[row, col]];
232            for k in 0..col {
233                acc -= lower[[row, k]] * lower[[col, k]];
234            }
235            if row == col {
236                if !(acc.is_finite() && acc > 0.0) {
237                    return None;
238                }
239                lower[[row, col]] = acc.sqrt();
240            } else {
241                let diagonal = lower[[col, col]];
242                if !(diagonal.is_finite() && diagonal > 0.0) {
243                    return None;
244                }
245                lower[[row, col]] = acc / diagonal;
246            }
247        }
248    }
249    Some(lower)
250}
251
252fn seed_from_problem(rho_hat: &Array1<f64>, size: RhoUncertaintyProblemSize) -> u64 {
253    let mut state = 0xcbf2_9ce4_8422_2325_u64;
254    mix_u64(&mut state, size.n_obs.unwrap_or(0) as u64);
255    mix_u64(&mut state, size.p_coefficients.unwrap_or(0) as u64);
256    mix_u64(&mut state, rho_hat.len() as u64);
257    for value in rho_hat {
258        mix_u64(&mut state, value.to_bits());
259    }
260    state
261}
262
/// Folds the eight little-endian bytes of `value` into `state`.
///
/// Each byte is XOR-ed into the state, which is then multiplied (wrapping) by
/// the 64-bit FNV prime, i.e. one FNV-1a step per byte.
fn mix_u64(state: &mut u64, value: u64) {
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    for byte in value.to_le_bytes() {
        *state = (*state ^ u64::from(byte)).wrapping_mul(FNV_PRIME);
    }
}
269
/// Seeded, fully reproducible source of standard-normal draws.
struct DeterministicNormal {
    /// Current generator state.
    state: u64,
    /// Second value of the most recently generated pair, held back for the
    /// next request.
    spare: Option<f64>,
}
274
275impl DeterministicNormal {
276    fn new(seed: u64) -> Self {
277        Self {
278            state: seed,
279            spare: None,
280        }
281    }
282
283    fn normal(&mut self, coord: usize) -> f64 {
284        if let Some(value) = self.spare.take() {
285            return value;
286        }
287        mix_u64(&mut self.state, coord as u64);
288        let u1 = self.uniform().max(1e-300);
289        let u2 = self.uniform();
290        let radius = (-2.0 * u1.ln()).sqrt();
291        let angle = 2.0 * std::f64::consts::PI * u2;
292        self.spare = Some(radius * angle.sin());
293        radius * angle.cos()
294    }
295
296    fn uniform(&mut self) -> f64 {
297        self.state = self.state.wrapping_add(0x9e37_79b9_7f4a_7c15);
298        let mut z = self.state;
299        z = (z ^ (z >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
300        z = (z ^ (z >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
301        z ^= z >> 31;
302        ((z >> 11) as f64 + 0.5) / (1_u64 << 53) as f64
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Criterion that is exactly the quadratic form implied by `hessian`
    /// around `rho_hat`, i.e. a target that matches the Gaussian proposal.
    fn gaussian_criterion(
        rho_hat: Array1<f64>,
        hessian: Array2<f64>,
    ) -> impl FnMut(&Array1<f64>) -> Option<f64> {
        move |rho: &Array1<f64>| {
            let delta = rho - &rho_hat;
            Some(0.5 * delta.dot(&hessian.dot(&delta)))
        }
    }

    /// A target equal to the proposal must not be flagged as heavy-tailed.
    #[test]
    fn near_gaussian_target_has_no_heavy_tail_evidence_at_probe_points() {
        let rho_hat = array![0.2, -0.3];
        let hessian = array![[2.5, 0.2], [0.2, 1.8]];
        let diagnostic = rho_uncertainty_diagnostic(
            &rho_hat,
            &hessian,
            RhoUncertaintyCostGate {
                sample_count: 32,
                problem_size: RhoUncertaintyProblemSize {
                    n_obs: Some(40),
                    p_coefficients: Some(8),
                },
            },
            gaussian_criterion(rho_hat.clone(), hessian.clone()),
        );
        assert!(
            matches!(
                diagnostic.status,
                RhoUncertaintyStatus::NoEvidenceOfHeavyTails
            ),
            "near-Gaussian rho posterior should not show heavy-tail evidence at the probe \
             points, got {diagnostic:?}"
        );
        assert!(
            diagnostic.k_hat.expect("k_hat") < 0.5,
            "near-Gaussian target should have k_hat below 0.5"
        );
    }

    /// A criterion that grows only logarithmically must yield a larger
    /// `k_hat` than the matching Gaussian criterion.
    #[test]
    fn weak_identification_orders_above_gaussian_case() {
        let rho_hat = array![0.0];
        let hessian = array![[5.0]];
        let gate = RhoUncertaintyCostGate {
            sample_count: 64,
            problem_size: RhoUncertaintyProblemSize {
                n_obs: Some(12),
                p_coefficients: Some(4),
            },
        };
        let gaussian = rho_uncertainty_diagnostic(
            &rho_hat,
            &hessian,
            gate,
            gaussian_criterion(rho_hat.clone(), hessian.clone()),
        );
        let weak = rho_uncertainty_diagnostic(&rho_hat, &hessian, gate, |rho| {
            Some((1.0 + rho[0] * rho[0]).ln())
        });
        assert!(
            weak.k_hat.expect("weak k_hat") > gaussian.k_hat.expect("gaussian k_hat"),
            "weak rho identification should increase k_hat: weak={weak:?} gaussian={gaussian:?}"
        );
    }

    /// Two runs on identical inputs must give identical results.
    #[test]
    fn diagnostic_is_bit_deterministic() {
        let rho_hat = array![0.7];
        let hessian = array![[1.4]];
        let gate = RhoUncertaintyCostGate {
            sample_count: 32,
            problem_size: RhoUncertaintyProblemSize {
                n_obs: Some(80),
                p_coefficients: Some(9),
            },
        };
        let a = rho_uncertainty_diagnostic(
            &rho_hat,
            &hessian,
            gate,
            gaussian_criterion(rho_hat.clone(), hessian.clone()),
        );
        let b = rho_uncertainty_diagnostic(
            &rho_hat,
            &hessian,
            gate,
            gaussian_criterion(rho_hat.clone(), hessian.clone()),
        );
        assert_eq!(a, b);
    }

    /// Five smoothing parameters exceed the automatic dimension limit, so
    /// the diagnostic must be skipped.
    #[test]
    fn cost_gate_skips_large_problem() {
        let rho_hat = array![0.0, 0.0, 0.0, 0.0, 0.0];
        let hessian = array![
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0],
        ];
        let diagnostic = rho_uncertainty_diagnostic(
            &rho_hat,
            &hessian,
            RhoUncertaintyCostGate::default(),
            |_| Some(0.0),
        );
        assert!(matches!(
            diagnostic.status,
            RhoUncertaintyStatus::Skipped { .. }
        ));
    }
}