1use crate::psis::{MIN_TAIL_COUNT, pareto_smooth_weights};
13use ndarray::{Array1, Array2};
14
/// Number of proposal draws used when a caller relies on `RhoUncertaintyCostGate::default()`.
const DEFAULT_SAMPLE_COUNT: usize = 32;
/// Largest smoothing-parameter dimension `K` for which the diagnostic runs automatically.
const MAX_AUTO_RHO_DIM: usize = 4;
/// Upper bound on the `(M + 1) * K * n * p` work-unit estimate accepted by `cost_gate_allows`.
const MAX_AUTO_WORK_UNITS: usize = 2 * 1_000_000;
18
/// Outcome of the PSIS-based rho-uncertainty diagnostic.
#[derive(Debug, PartialEq, Clone)]
pub struct RhoUncertaintyDiagnostic {
    /// Pareto shape estimate of the importance-weight tail, or `None` when the
    /// diagnostic was skipped.
    pub k_hat: Option<f64>,
    /// Number of criterion evaluations performed before the result was produced.
    pub n_evaluations: usize,
    /// Interpretation of `k_hat` (or the reason the diagnostic did not run).
    pub status: RhoUncertaintyStatus,
}

/// Classification attached to a [`RhoUncertaintyDiagnostic`].
#[derive(Debug, PartialEq, Clone)]
pub enum RhoUncertaintyStatus {
    /// The importance weights gave no sign of a heavy tail.
    NoEvidenceOfHeavyTails,
    /// The fitted tail shape `k_hat` was not below the heavy-tail threshold.
    HeavyTailsDetected { k_hat: f64 },
    /// The diagnostic did not produce an estimate; `reason` says why.
    Skipped { reason: String },
}
32
/// Optional size information used to estimate the diagnostic's cost and to seed its RNG.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct RhoUncertaintyProblemSize {
    /// Observation count `n`; `None` when unknown.
    pub n_obs: Option<usize>,
    /// Coefficient count `p`; `None` when unknown.
    pub p_coefficients: Option<usize>,
}

/// Configuration for the cost gate that decides whether the diagnostic runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RhoUncertaintyCostGate {
    /// Requested number of proposal draws `M` (may be raised by the gate).
    pub sample_count: usize,
    /// Problem dimensions feeding the work-unit estimate.
    pub problem_size: RhoUncertaintyProblemSize,
}
44
45impl Default for RhoUncertaintyCostGate {
46 fn default() -> Self {
47 Self {
48 sample_count: DEFAULT_SAMPLE_COUNT,
49 problem_size: RhoUncertaintyProblemSize::default(),
50 }
51 }
52}
53
54impl RhoUncertaintyDiagnostic {
55 pub fn skipped(reason: impl Into<String>, n_evaluations: usize) -> Self {
56 Self {
57 k_hat: None,
58 n_evaluations,
59 status: RhoUncertaintyStatus::Skipped {
60 reason: reason.into(),
61 },
62 }
63 }
64}
65
66pub fn cost_gate_allows(rho_dim: usize, gate: RhoUncertaintyCostGate) -> Result<usize, String> {
67 if rho_dim == 0 {
68 return Err("no smoothing parameters".to_string());
69 }
70 if rho_dim > MAX_AUTO_RHO_DIM {
71 return Err(format!(
72 "rho dimension {rho_dim} exceeds automatic PSIS diagnostic limit {MAX_AUTO_RHO_DIM}"
73 ));
74 }
75 let sample_count = gate.sample_count.max(2 * MIN_TAIL_COUNT);
76 let n = gate.problem_size.n_obs.unwrap_or(1);
77 let p = gate.problem_size.p_coefficients.unwrap_or(1);
78 let work_units = sample_count
79 .saturating_add(1)
80 .saturating_mul(rho_dim.max(1))
81 .saturating_mul(n.max(1))
82 .saturating_mul(p.max(1));
83 if work_units > MAX_AUTO_WORK_UNITS {
84 return Err(format!(
85 "estimated diagnostic cost {work_units} work units exceeds automatic limit {MAX_AUTO_WORK_UNITS} \
86 (M={sample_count}, K={rho_dim}, n={}, p={})",
87 gate.problem_size.n_obs.unwrap_or(0),
88 gate.problem_size.p_coefficients.unwrap_or(0),
89 ));
90 }
91 Ok(sample_count)
92}
93
94pub fn rho_uncertainty_diagnostic<F>(
95 rho_hat: &Array1<f64>,
96 outer_hessian_rho: &Array2<f64>,
97 gate: RhoUncertaintyCostGate,
98 mut criterion: F,
99) -> RhoUncertaintyDiagnostic
100where
101 F: FnMut(&Array1<f64>) -> Option<f64>,
102{
103 let rho_dim = rho_hat.len();
104 let sample_count = match cost_gate_allows(rho_dim, gate) {
105 Ok(sample_count) => sample_count,
106 Err(reason) => return RhoUncertaintyDiagnostic::skipped(reason, 0),
107 };
108 if outer_hessian_rho.nrows() != rho_dim || outer_hessian_rho.ncols() != rho_dim {
109 return RhoUncertaintyDiagnostic::skipped(
110 format!(
111 "outer rho Hessian shape {}x{} does not match K={rho_dim}",
112 outer_hessian_rho.nrows(),
113 outer_hessian_rho.ncols()
114 ),
115 0,
116 );
117 }
118 let Some(cost_hat) = criterion(rho_hat).filter(|value| value.is_finite()) else {
119 return RhoUncertaintyDiagnostic::skipped("criterion was not finite at rho_hat", 1);
120 };
121 let Some(proposal_factor) = proposal_factor_from_hessian(outer_hessian_rho) else {
122 return RhoUncertaintyDiagnostic::skipped("outer rho Hessian was not positive definite", 1);
123 };
124
125 let mut rng = DeterministicNormal::new(seed_from_problem(rho_hat, gate.problem_size));
126 let mut log_weights = Vec::with_capacity(sample_count);
127 let mut n_evaluations = 1usize;
128 for _draw in 0..sample_count {
129 let z = Array1::from_iter((0..rho_dim).map(|coord| rng.normal(coord)));
130 let rho = rho_hat + &proposal_factor.dot(&z);
131 let half_norm_sq = 0.5 * z.iter().map(|value| value * value).sum::<f64>();
132 let log_weight = match criterion(&rho) {
133 Some(cost) if cost.is_finite() => -cost + cost_hat + half_norm_sq,
134 _ => f64::NEG_INFINITY,
135 };
136 log_weights.push(log_weight);
137 n_evaluations = n_evaluations.saturating_add(1);
138 }
139
140 let max_log_weight = log_weights
141 .iter()
142 .copied()
143 .filter(|value| value.is_finite())
144 .fold(f64::NEG_INFINITY, f64::max);
145 if !max_log_weight.is_finite() {
146 return RhoUncertaintyDiagnostic::skipped(
147 "all proposal draws had non-finite criterion values",
148 n_evaluations,
149 );
150 }
151 let weights: Vec<f64> = log_weights
152 .iter()
153 .map(|&value| {
154 if value.is_finite() {
155 (value - max_log_weight).exp()
156 } else {
157 0.0
158 }
159 })
160 .collect();
161 let (min_weight, max_weight) = weights
162 .iter()
163 .copied()
164 .fold((f64::INFINITY, f64::NEG_INFINITY), |(min_w, max_w), w| {
165 (min_w.min(w), max_w.max(w))
166 });
167 if max_weight.is_finite()
168 && min_weight.is_finite()
169 && max_weight > 0.0
170 && (max_weight - min_weight) <= 1e-12 * max_weight.max(1.0)
171 {
172 return RhoUncertaintyDiagnostic {
173 k_hat: Some(0.0),
174 n_evaluations,
175 status: RhoUncertaintyStatus::NoEvidenceOfHeavyTails,
176 };
177 }
178 let Some(psis) = pareto_smooth_weights(&weights) else {
179 return RhoUncertaintyDiagnostic::skipped(
180 "PSIS tail fit failed for rho-importance weights",
181 n_evaluations,
182 );
183 };
184 let k_hat = psis.k_hat;
185 let status = if k_hat < 0.5 {
186 RhoUncertaintyStatus::NoEvidenceOfHeavyTails
187 } else {
188 RhoUncertaintyStatus::HeavyTailsDetected { k_hat }
189 };
190 RhoUncertaintyDiagnostic {
191 k_hat: Some(k_hat),
192 n_evaluations,
193 status,
194 }
195}
196
197fn proposal_factor_from_hessian(hessian: &Array2<f64>) -> Option<Array2<f64>> {
198 let chol = cholesky_lower(hessian)?;
199 let n = chol.nrows();
200 let mut inverse_lower = Array2::<f64>::zeros((n, n));
201 for col in 0..n {
202 for row in 0..n {
203 let mut acc = if row == col { 1.0 } else { 0.0 };
204 for k in 0..row {
205 acc -= chol[[row, k]] * inverse_lower[[k, col]];
206 }
207 let diagonal = chol[[row, row]];
208 if !(diagonal.is_finite() && diagonal > 0.0) {
209 return None;
210 }
211 inverse_lower[[row, col]] = acc / diagonal;
212 }
213 }
214 let mut factor = Array2::<f64>::zeros((n, n));
215 for row in 0..n {
216 for col in 0..n {
217 factor[[row, col]] = inverse_lower[[col, row]];
218 }
219 }
220 Some(factor)
221}
222
223fn cholesky_lower(matrix: &Array2<f64>) -> Option<Array2<f64>> {
224 let n = matrix.nrows();
225 if n == 0 || matrix.ncols() != n || matrix.iter().any(|value| !value.is_finite()) {
226 return None;
227 }
228 let mut lower = Array2::<f64>::zeros((n, n));
229 for row in 0..n {
230 for col in 0..=row {
231 let mut acc = matrix[[row, col]];
232 for k in 0..col {
233 acc -= lower[[row, k]] * lower[[col, k]];
234 }
235 if row == col {
236 if !(acc.is_finite() && acc > 0.0) {
237 return None;
238 }
239 lower[[row, col]] = acc.sqrt();
240 } else {
241 let diagonal = lower[[col, col]];
242 if !(diagonal.is_finite() && diagonal > 0.0) {
243 return None;
244 }
245 lower[[row, col]] = acc / diagonal;
246 }
247 }
248 }
249 Some(lower)
250}
251
252fn seed_from_problem(rho_hat: &Array1<f64>, size: RhoUncertaintyProblemSize) -> u64 {
253 let mut state = 0xcbf2_9ce4_8422_2325_u64;
254 mix_u64(&mut state, size.n_obs.unwrap_or(0) as u64);
255 mix_u64(&mut state, size.p_coefficients.unwrap_or(0) as u64);
256 mix_u64(&mut state, rho_hat.len() as u64);
257 for value in rho_hat {
258 mix_u64(&mut state, value.to_bits());
259 }
260 state
261}
262
/// Folds the eight little-endian bytes of `value` into `state`, FNV-1a style
/// (xor each byte in, then multiply by the 64-bit FNV prime, wrapping).
fn mix_u64(state: &mut u64, value: u64) {
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    *state = value
        .to_le_bytes()
        .iter()
        .fold(*state, |hash, &byte| (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME));
}
269
270struct DeterministicNormal {
271 state: u64,
272 spare: Option<f64>,
273}
274
275impl DeterministicNormal {
276 fn new(seed: u64) -> Self {
277 Self {
278 state: seed,
279 spare: None,
280 }
281 }
282
283 fn normal(&mut self, coord: usize) -> f64 {
284 if let Some(value) = self.spare.take() {
285 return value;
286 }
287 mix_u64(&mut self.state, coord as u64);
288 let u1 = self.uniform().max(1e-300);
289 let u2 = self.uniform();
290 let radius = (-2.0 * u1.ln()).sqrt();
291 let angle = 2.0 * std::f64::consts::PI * u2;
292 self.spare = Some(radius * angle.sin());
293 radius * angle.cos()
294 }
295
296 fn uniform(&mut self) -> f64 {
297 self.state = self.state.wrapping_add(0x9e37_79b9_7f4a_7c15);
298 let mut z = self.state;
299 z = (z ^ (z >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
300 z = (z ^ (z >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
301 z ^= z >> 31;
302 ((z >> 11) as f64 + 0.5) / (1_u64 << 53) as f64
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Quadratic criterion `0.5 * (rho - rho_hat)^T H (rho - rho_hat)`. Its
    /// `exp(-criterion)` target coincides with the Gaussian proposal built from `H`.
    fn gaussian_criterion(
        rho_hat: Array1<f64>,
        hessian: Array2<f64>,
    ) -> impl FnMut(&Array1<f64>) -> Option<f64> {
        move |rho: &Array1<f64>| {
            let delta = rho - &rho_hat;
            Some(0.5 * delta.dot(&hessian.dot(&delta)))
        }
    }

    #[test]
    fn near_gaussian_target_has_no_heavy_tail_evidence_at_probe_points() {
        let rho_hat = array![0.2, -0.3];
        let hessian = array![[2.5, 0.2], [0.2, 1.8]];
        let diagnostic = rho_uncertainty_diagnostic(
            &rho_hat,
            &hessian,
            RhoUncertaintyCostGate {
                sample_count: 32,
                problem_size: RhoUncertaintyProblemSize {
                    n_obs: Some(40),
                    p_coefficients: Some(8),
                },
            },
            gaussian_criterion(rho_hat.clone(), hessian.clone()),
        );
        assert!(
            matches!(
                diagnostic.status,
                RhoUncertaintyStatus::NoEvidenceOfHeavyTails
            ),
            "near-Gaussian rho posterior should not show heavy-tail evidence at the probe \
             points, got {diagnostic:?}"
        );
        assert!(
            diagnostic.k_hat.expect("k_hat") < 0.5,
            "near-Gaussian target should have k_hat below 0.5"
        );
    }

    #[test]
    fn weak_identification_orders_above_gaussian_case() {
        let rho_hat = array![0.0];
        let hessian = array![[5.0]];
        let gate = RhoUncertaintyCostGate {
            sample_count: 64,
            problem_size: RhoUncertaintyProblemSize {
                n_obs: Some(12),
                p_coefficients: Some(4),
            },
        };
        let gaussian = rho_uncertainty_diagnostic(
            &rho_hat,
            &hessian,
            gate,
            gaussian_criterion(rho_hat.clone(), hessian.clone()),
        );
        let weak = rho_uncertainty_diagnostic(&rho_hat, &hessian, gate, |rho| {
            Some((1.0 + rho[0] * rho[0]).ln())
        });
        assert!(
            weak.k_hat.expect("weak k_hat") > gaussian.k_hat.expect("gaussian k_hat"),
            "weak rho identification should increase k_hat: weak={weak:?} gaussian={gaussian:?}"
        );
    }

    #[test]
    fn diagnostic_is_bit_deterministic() {
        let rho_hat = array![0.7];
        let hessian = array![[1.4]];
        let gate = RhoUncertaintyCostGate {
            sample_count: 32,
            problem_size: RhoUncertaintyProblemSize {
                n_obs: Some(80),
                p_coefficients: Some(9),
            },
        };
        let a = rho_uncertainty_diagnostic(
            &rho_hat,
            &hessian,
            gate,
            gaussian_criterion(rho_hat.clone(), hessian.clone()),
        );
        let b = rho_uncertainty_diagnostic(
            &rho_hat,
            &hessian,
            gate,
            gaussian_criterion(rho_hat.clone(), hessian.clone()),
        );
        assert_eq!(a, b);
    }

    /// Exercises the rho-dimension limit (K = 5 exceeds MAX_AUTO_RHO_DIM).
    #[test]
    fn cost_gate_skips_large_problem() {
        let rho_hat = array![0.0, 0.0, 0.0, 0.0, 0.0];
        let hessian = array![
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0],
        ];
        let diagnostic = rho_uncertainty_diagnostic(
            &rho_hat,
            &hessian,
            RhoUncertaintyCostGate::default(),
            |_| Some(0.0),
        );
        assert!(matches!(
            diagnostic.status,
            RhoUncertaintyStatus::Skipped { .. }
        ));
    }

    /// Exercises the zero-dimension check and the work-unit limit of the cost gate.
    #[test]
    fn cost_gate_rejects_empty_rho_and_excessive_work() {
        assert!(cost_gate_allows(0, RhoUncertaintyCostGate::default()).is_err());
        let heavy = RhoUncertaintyCostGate {
            sample_count: 32,
            problem_size: RhoUncertaintyProblemSize {
                n_obs: Some(1_000_000),
                p_coefficients: Some(100),
            },
        };
        let reason = cost_gate_allows(1, heavy).expect_err("work-unit limit should reject this");
        assert!(reason.contains("work units"), "unexpected reason: {reason}");
    }

    #[test]
    fn cost_gate_enforces_minimum_sample_count() {
        let gate = RhoUncertaintyCostGate {
            sample_count: 1,
            problem_size: RhoUncertaintyProblemSize::default(),
        };
        let sample_count = cost_gate_allows(1, gate).expect("tiny problem should pass the gate");
        assert_eq!(sample_count, 1usize.max(2 * MIN_TAIL_COUNT));
    }

    #[test]
    fn mismatched_hessian_shape_is_skipped_without_evaluations() {
        let rho_hat = array![0.0, 1.0];
        let hessian = array![[1.0]];
        let diagnostic = rho_uncertainty_diagnostic(
            &rho_hat,
            &hessian,
            RhoUncertaintyCostGate::default(),
            |_| Some(0.0),
        );
        assert!(matches!(
            diagnostic.status,
            RhoUncertaintyStatus::Skipped { .. }
        ));
        assert_eq!(diagnostic.n_evaluations, 0);
        assert_eq!(diagnostic.k_hat, None);
    }

    #[test]
    fn non_finite_criterion_at_mode_is_skipped_after_one_evaluation() {
        let rho_hat = array![0.0];
        let hessian = array![[1.0]];
        let diagnostic = rho_uncertainty_diagnostic(
            &rho_hat,
            &hessian,
            RhoUncertaintyCostGate::default(),
            |_| Some(f64::NAN),
        );
        assert!(matches!(
            diagnostic.status,
            RhoUncertaintyStatus::Skipped { .. }
        ));
        assert_eq!(diagnostic.n_evaluations, 1);
    }

    #[test]
    fn non_positive_definite_hessian_is_skipped() {
        let rho_hat = array![0.0];
        let hessian = array![[-1.0]];
        let diagnostic = rho_uncertainty_diagnostic(
            &rho_hat,
            &hessian,
            RhoUncertaintyCostGate::default(),
            |_| Some(0.0),
        );
        assert!(matches!(
            diagnostic.status,
            RhoUncertaintyStatus::Skipped { .. }
        ));
        assert_eq!(diagnostic.n_evaluations, 1);
    }

    #[test]
    fn all_failed_draws_are_skipped_after_full_sampling() {
        let rho_hat = array![0.0];
        let hessian = array![[1.0]];
        let gate = RhoUncertaintyCostGate::default();
        let expected_draws = cost_gate_allows(1, gate).expect("default gate should pass");
        let mut calls = 0usize;
        let diagnostic = rho_uncertainty_diagnostic(&rho_hat, &hessian, gate, |_| {
            calls += 1;
            // Only the evaluation at rho_hat succeeds; every proposal draw fails.
            if calls == 1 {
                Some(0.0)
            } else {
                None
            }
        });
        assert!(matches!(
            diagnostic.status,
            RhoUncertaintyStatus::Skipped { .. }
        ));
        assert_eq!(diagnostic.n_evaluations, expected_draws + 1);
        assert_eq!(calls, expected_draws + 1);
    }
}