1use rand::SeedableRng;
48use rand::rngs::StdRng;
49
50use crate::differential_evolution;
51use crate::{DEConfigBuilder, init_latin_hypercube::init_latin_hypercube};
52use ndarray::{Array1, Array2};
53
/// Prior distribution over the `D` uncertain "area" coordinates that a loss is
/// aggregated over (see `try_evaluate_area_loss`).
///
/// Every variant can be reduced to an axis-aligned box via [`Prior::bounding_box`].
#[derive(Clone)]
pub enum Prior<const D: usize> {
    /// Uniform distribution on the axis-aligned box `bounds`.
    Uniform {
        /// Per-axis `(lo, hi)` limits; `validate` requires finite values with `lo < hi`.
        bounds: [(f64, f64); D],
    },
    /// Independent (diagonal-covariance) normal distribution, truncated on each
    /// axis to `mean ± truncation_sigmas * sqrt(cov_diag)`.
    Gaussian {
        /// Per-axis mean.
        mean: [f64; D],
        /// Per-axis variance (not standard deviation); must be finite and > 0.
        cov_diag: [f64; D],
        /// Half-width of the truncation box in standard deviations; must be finite and > 0.
        truncation_sigmas: f64,
    },
    /// User-supplied density on the box `bounds`.
    Custom {
        /// Per-axis `(lo, hi)` limits; same validity rules as `Uniform`.
        bounds: [(f64, f64); D],
        /// Density evaluated at a point of `bounds`. The quadrature builders clamp
        /// negative values to zero and normalise the resulting weights, so the
        /// density does not need to integrate to one.
        density: std::sync::Arc<dyn Fn([f64; D]) -> f64 + Send + Sync>,
    },
}
90
91impl<const D: usize> std::fmt::Debug for Prior<D> {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 match self {
94 Prior::Uniform { bounds } => f
95 .debug_struct("Prior::Uniform")
96 .field("bounds", bounds)
97 .finish(),
98 Prior::Gaussian {
99 mean,
100 cov_diag,
101 truncation_sigmas,
102 } => f
103 .debug_struct("Prior::Gaussian")
104 .field("mean", mean)
105 .field("cov_diag", cov_diag)
106 .field("truncation_sigmas", truncation_sigmas)
107 .finish(),
108 Prior::Custom { bounds, .. } => f
109 .debug_struct("Prior::Custom")
110 .field("bounds", bounds)
111 .field("density", &"<closure>")
112 .finish(),
113 }
114 }
115}
116
117impl<const D: usize> Prior<D> {
118 pub fn validate(&self) -> Result<(), AreaError> {
120 match self {
121 Prior::Uniform { bounds } | Prior::Custom { bounds, .. } => {
122 for (i, (lo, hi)) in bounds.iter().enumerate() {
123 if !(lo.is_finite() && hi.is_finite()) || hi <= lo {
124 return Err(AreaError::InvalidPrior(format!(
125 "axis {} bounds [{}, {}] are degenerate",
126 i, lo, hi
127 )));
128 }
129 }
130 Ok(())
131 }
132 Prior::Gaussian {
133 cov_diag,
134 truncation_sigmas,
135 ..
136 } => {
137 if !truncation_sigmas.is_finite() || *truncation_sigmas <= 0.0 {
138 return Err(AreaError::InvalidPrior(format!(
139 "Gaussian truncation_sigmas must be > 0, got {}",
140 truncation_sigmas
141 )));
142 }
143 for (i, &v) in cov_diag.iter().enumerate() {
144 if !v.is_finite() || v <= 0.0 {
145 return Err(AreaError::InvalidPrior(format!(
146 "Gaussian variance on axis {} must be > 0, got {}",
147 i, v
148 )));
149 }
150 }
151 Ok(())
152 }
153 }
154 }
155
156 pub fn bounding_box(&self) -> [(f64, f64); D] {
159 match self {
160 Prior::Uniform { bounds } | Prior::Custom { bounds, .. } => *bounds,
161 Prior::Gaussian {
162 mean,
163 cov_diag,
164 truncation_sigmas,
165 } => {
166 let mut out = [(0.0_f64, 0.0_f64); D];
167 for i in 0..D {
168 let sigma = cov_diag[i].sqrt();
169 out[i] = (
170 mean[i] - truncation_sigmas * sigma,
171 mean[i] + truncation_sigmas * sigma,
172 );
173 }
174 out
175 }
176 }
177 }
178}
179
/// Rule used to turn a [`Prior`] into a finite set of weighted nodes.
#[derive(Debug, Clone)]
pub enum Quadrature<const D: usize> {
    /// Quasi-Monte-Carlo points in the unit cube, mapped onto the prior.
    ///
    /// NOTE(review): the current generator (`sobol_unit`) calls
    /// `init_sobol::init_halton`, so the points appear to be a Halton sequence, and
    /// `seed` is ignored. Confirm whether a (scrambled) Sobol sequence was intended.
    Sobol {
        /// Number of points; must be > 0.
        num_points: usize,
        /// Seed; currently unused by the generator.
        seed: u64,
    },
    /// Latin-hypercube sample of the unit cube, mapped onto the prior.
    LatinHypercube {
        /// Number of points; must be > 0.
        num_points: usize,
        /// Seed for the `StdRng` that drives the design.
        seed: u64,
    },
    /// Tensor-product Gauss–Legendre grid over the prior's box, with
    /// `points_per_axis^D` nodes. Only `Uniform` and `Custom` priors are supported.
    GaussLegendre {
        /// Nodes per axis; must be > 0.
        points_per_axis: usize,
    },
}
209
/// How the distribution of the loss over the area prior is reduced to a single number.
#[derive(Debug, Clone, Copy)]
pub enum AreaScalarisation {
    /// Quadrature-weighted mean of the loss.
    ExpectedValue,
    /// Largest loss over the prior's bounding box, found with an inner
    /// differential-evolution search. The quadrature rule is not used.
    WorstCase {
        /// Iteration budget for the inner search (raised to at least 5).
        inner_maxiter: usize,
        /// Seed for the inner search.
        inner_seed: u64,
    },
    /// Conditional value-at-risk: weighted mean of the largest losses that
    /// together carry probability mass `alpha`.
    Cvar {
        /// Tail mass; must lie in `(0, 1]`.
        alpha: f64,
    },
}
233
/// Errors produced while building quadrature rules or evaluating area losses.
#[derive(Debug, thiserror::Error)]
pub enum AreaError {
    /// The prior is unusable. Examples: degenerate bounds, non-positive variance, or a
    /// custom density that is zero on every node.
    #[error("invalid prior: {0}")]
    InvalidPrior(String),
    /// The quadrature settings are unusable (e.g. zero points). This variant is also
    /// returned for an out-of-range CVaR `alpha` and for zero total CVaR weight.
    #[error("invalid quadrature: {0}")]
    InvalidQuadrature(String),
    /// The quadrature rule cannot be applied to this prior family
    /// (Gauss–Legendre on a Gaussian prior).
    #[error("incompatible prior/quadrature: {0}")]
    IncompatiblePriorQuadrature(String),
    /// The inner differential-evolution search used by `WorstCase` could not be
    /// configured or failed to run.
    #[error("inner worst-case search failed: {0}")]
    InnerSearchFailed(String),
}
250
251pub fn build_quadrature_points<const D: usize>(
262 prior: &Prior<D>,
263 quadrature: &Quadrature<D>,
264) -> Result<(Vec<[f64; D]>, Vec<f64>), AreaError> {
265 prior.validate()?;
266 let bounds = prior.bounding_box();
267
268 match quadrature {
269 Quadrature::Sobol { num_points, seed } => {
270 if *num_points == 0 {
271 return Err(AreaError::InvalidQuadrature(
272 "Sobol num_points must be > 0".into(),
273 ));
274 }
275 let raw = sobol_unit(*num_points, *seed);
276 transform_unit_samples(&raw, prior, &bounds)
277 }
278 Quadrature::LatinHypercube { num_points, seed } => {
279 if *num_points == 0 {
280 return Err(AreaError::InvalidQuadrature(
281 "LatinHypercube num_points must be > 0".into(),
282 ));
283 }
284 let raw = latin_hypercube_unit::<D>(*num_points, *seed);
285 transform_unit_samples(&raw, prior, &bounds)
286 }
287 Quadrature::GaussLegendre { points_per_axis } => {
288 if *points_per_axis == 0 {
289 return Err(AreaError::InvalidQuadrature(
290 "GaussLegendre points_per_axis must be > 0".into(),
291 ));
292 }
293 match prior {
294 Prior::Uniform { bounds } => Ok(gauss_legendre_tensor(*points_per_axis, bounds)),
295 Prior::Custom { bounds, density } => {
296 let (pts, mut weights) = gauss_legendre_tensor(*points_per_axis, bounds);
298 for (p, w) in pts.iter().zip(weights.iter_mut()) {
299 *w *= density(*p).max(0.0);
300 }
301 let total: f64 = weights.iter().sum();
302 if total <= 0.0 {
303 return Err(AreaError::InvalidPrior(
304 "Custom density evaluated to zero on every quadrature node".into(),
305 ));
306 }
307 for w in weights.iter_mut() {
308 *w /= total;
309 }
310 Ok((pts, weights))
311 }
312 Prior::Gaussian { .. } => Err(AreaError::IncompatiblePriorQuadrature(
313 "GaussLegendre on a Gaussian prior would require Gauss–Hermite; \
314 use Sobol or LatinHypercube for unbounded priors"
315 .into(),
316 )),
317 }
318 }
319 }
320}
321
322pub fn evaluate_area_loss<F, const D: usize>(
328 loss: &F,
329 params: &[f64],
330 prior: &Prior<D>,
331 quadrature: &Quadrature<D>,
332 scalarisation: AreaScalarisation,
333) -> f64
334where
335 F: Fn(&[f64], [f64; D]) -> f64 + Sync,
336{
337 try_evaluate_area_loss(loss, params, prior, quadrature, scalarisation)
338 .unwrap_or_else(|e| panic!("evaluate_area_loss: {e}"))
339}
340
341pub fn try_evaluate_area_loss<F, const D: usize>(
343 loss: &F,
344 params: &[f64],
345 prior: &Prior<D>,
346 quadrature: &Quadrature<D>,
347 scalarisation: AreaScalarisation,
348) -> Result<f64, AreaError>
349where
350 F: Fn(&[f64], [f64; D]) -> f64 + Sync,
351{
352 match scalarisation {
353 AreaScalarisation::WorstCase {
354 inner_maxiter,
355 inner_seed,
356 } => worst_case_via_de(loss, params, prior, inner_maxiter, inner_seed),
357 AreaScalarisation::ExpectedValue => {
358 let (points, weights) = build_quadrature_points(prior, quadrature)?;
359 let mut acc = 0.0;
360 for (p, w) in points.iter().zip(weights.iter()) {
361 acc += w * loss(params, *p);
362 }
363 Ok(acc)
364 }
365 AreaScalarisation::Cvar { alpha } => {
366 if !(0.0..=1.0).contains(&alpha) || alpha <= 0.0 {
367 return Err(AreaError::InvalidQuadrature(format!(
368 "CVaR alpha must be in (0, 1], got {}",
369 alpha
370 )));
371 }
372 let (points, weights) = build_quadrature_points(prior, quadrature)?;
373 let mut wl: Vec<(f64, f64)> = points
375 .iter()
376 .zip(weights.iter())
377 .map(|(p, &w)| (loss(params, *p), w))
378 .collect();
379 wl.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
381 let mut acc_loss = 0.0;
383 let mut acc_mass = 0.0;
384 for (l, w) in &wl {
385 let take = (alpha - acc_mass).min(*w);
386 if take <= 0.0 {
387 break;
388 }
389 acc_loss += take * l;
390 acc_mass += take;
391 if acc_mass >= alpha {
392 break;
393 }
394 }
395 if acc_mass <= 0.0 {
396 return Err(AreaError::InvalidQuadrature(
397 "CVaR encountered zero total importance weight".into(),
398 ));
399 }
400 Ok(acc_loss / acc_mass)
401 }
402 }
403}
404
405fn sobol_unit<const D: usize>(num_points: usize, _seed: u64) -> Vec<[f64; D]> {
410 let unit_bounds: Vec<(f64, f64)> = (0..D).map(|_| (0.0, 1.0)).collect();
414 let raw = crate::init_sobol::init_halton(D, num_points, &unit_bounds);
415 raw.into_iter()
416 .map(|v| {
417 let mut out = [0.0_f64; D];
418 for (i, x) in v.into_iter().enumerate().take(D) {
419 out[i] = x;
420 }
421 out
422 })
423 .collect()
424}
425
426fn latin_hypercube_unit<const D: usize>(num_points: usize, seed: u64) -> Vec<[f64; D]> {
427 let lower = Array1::<f64>::zeros(D);
428 let upper = Array1::<f64>::ones(D);
429 let is_free = vec![true; D];
430 let mut rng = StdRng::seed_from_u64(seed);
431 let m: Array2<f64> = init_latin_hypercube(D, num_points, &lower, &upper, &is_free, &mut rng);
432 (0..num_points)
433 .map(|row| {
434 let mut out = [0.0_f64; D];
435 for col in 0..D {
436 out[col] = m[(row, col)];
437 }
438 out
439 })
440 .collect()
441}
442
443fn transform_unit_samples<const D: usize>(
444 raw: &[[f64; D]],
445 prior: &Prior<D>,
446 bounds: &[(f64, f64); D],
447) -> Result<(Vec<[f64; D]>, Vec<f64>), AreaError> {
448 let n = raw.len();
449 let uniform_weight = 1.0 / n as f64;
450
451 match prior {
452 Prior::Uniform { .. } => {
453 let pts: Vec<[f64; D]> = raw
454 .iter()
455 .map(|u| {
456 let mut out = [0.0_f64; D];
457 for i in 0..D {
458 out[i] = bounds[i].0 + u[i] * (bounds[i].1 - bounds[i].0);
459 }
460 out
461 })
462 .collect();
463 Ok((pts, vec![uniform_weight; n]))
464 }
465 Prior::Gaussian { mean, cov_diag, .. } => {
466 let mut pts: Vec<[f64; D]> = Vec::with_capacity(n);
470 for u in raw {
471 let mut out = [0.0_f64; D];
472 for i in 0..D {
473 let sigma = cov_diag[i].sqrt();
474 let z_lo = (bounds[i].0 - mean[i]) / sigma;
476 let z_hi = (bounds[i].1 - mean[i]) / sigma;
477 let p_lo = standard_normal_cdf(z_lo);
478 let p_hi = standard_normal_cdf(z_hi);
479 let u_remap = p_lo + u[i] * (p_hi - p_lo);
480 let z = inv_standard_normal(u_remap);
481 out[i] = mean[i] + sigma * z;
482 }
483 pts.push(out);
484 }
485 Ok((pts, vec![uniform_weight; n]))
486 }
487 Prior::Custom { density, .. } => {
488 let pts: Vec<[f64; D]> = raw
490 .iter()
491 .map(|u| {
492 let mut out = [0.0_f64; D];
493 for i in 0..D {
494 out[i] = bounds[i].0 + u[i] * (bounds[i].1 - bounds[i].0);
495 }
496 out
497 })
498 .collect();
499 let mut weights: Vec<f64> = pts.iter().map(|p| density(*p).max(0.0)).collect();
500 let total: f64 = weights.iter().sum();
501 if total <= 0.0 {
502 return Err(AreaError::InvalidPrior(
503 "Custom density evaluated to zero on every sampled point".into(),
504 ));
505 }
506 for w in weights.iter_mut() {
507 *w /= total;
508 }
509 Ok((pts, weights))
510 }
511 }
512}
513
514fn gauss_legendre_tensor<const D: usize>(
515 points_per_axis: usize,
516 bounds: &[(f64, f64); D],
517) -> (Vec<[f64; D]>, Vec<f64>) {
518 let (nodes_unit, weights_unit) = gauss_legendre_1d(points_per_axis);
519 let mut nodes_per_axis: [Vec<f64>; D] = std::array::from_fn(|_| Vec::new());
523 let mut weights_per_axis: [Vec<f64>; D] = std::array::from_fn(|_| Vec::new());
524 for i in 0..D {
525 let (lo, hi) = bounds[i];
526 let mid = 0.5 * (hi + lo);
527 let half = 0.5 * (hi - lo);
528 let mut nodes = Vec::with_capacity(points_per_axis);
529 let mut weights = Vec::with_capacity(points_per_axis);
530 for k in 0..points_per_axis {
531 nodes.push(mid + half * nodes_unit[k]);
532 weights.push(half * weights_unit[k]);
533 }
534 nodes_per_axis[i] = nodes;
535 weights_per_axis[i] = weights;
536 }
537
538 let total: usize = points_per_axis.pow(D as u32);
539 let mut pts: Vec<[f64; D]> = Vec::with_capacity(total);
540 let mut wts: Vec<f64> = Vec::with_capacity(total);
541 for idx in 0..total {
542 let mut pt = [0.0_f64; D];
543 let mut w = 1.0_f64;
544 let mut k = idx;
545 for i in 0..D {
546 let ki = k % points_per_axis;
547 k /= points_per_axis;
548 pt[i] = nodes_per_axis[i][ki];
549 w *= weights_per_axis[i][ki];
550 }
551 pts.push(pt);
552 wts.push(w);
553 }
554
555 let total_w: f64 = wts.iter().sum();
557 if total_w > 0.0 {
558 for w in wts.iter_mut() {
559 *w /= total_w;
560 }
561 }
562
563 (pts, wts)
564}
565
/// `n`-point Gauss–Legendre nodes and weights on `[-1, 1]`, with nodes in ascending order.
///
/// Each root of the Legendre polynomial `P_n` is located by Newton's method, starting
/// from the classic guess `cos(pi * (i + 0.75) / (n + 0.5))`. The iteration runs for at
/// most 50 steps, or until a step is below `1e-15`. The weight at a root `x` is
/// `2 / ((1 - x^2) * P_n'(x)^2)`. The weights sum to 2, the length of the interval.
fn gauss_legendre_1d(n: usize) -> (Vec<f64>, Vec<f64>) {
    match n {
        0 => return (Vec::new(), Vec::new()),
        1 => return (vec![0.0], vec![2.0]),
        _ => {}
    }

    let order = n as f64;
    // Three-term recurrence: returns (P_n(x), P_{n-1}(x)).
    let legendre = |x: f64| -> (f64, f64) {
        let mut lower = 1.0_f64;
        let mut upper = x;
        for k in 1..n {
            let kf = k as f64;
            let next = ((2.0 * kf + 1.0) * x * upper - kf * lower) / (kf + 1.0);
            lower = upper;
            upper = next;
        }
        (upper, lower)
    };
    // P_n'(x) from P_n and P_{n-1}.
    let derivative = |x: f64, pn: f64, pn_minus_1: f64| order * (x * pn - pn_minus_1) / (x * x - 1.0);

    let mut rule: Vec<(f64, f64)> = (0..n)
        .map(|i| {
            let mut x = (std::f64::consts::PI * (i as f64 + 0.75) / (order + 0.5)).cos();
            for _ in 0..50 {
                let (pn, pn_minus_1) = legendre(x);
                let step = pn / derivative(x, pn, pn_minus_1);
                x -= step;
                if step.abs() < 1e-15 {
                    break;
                }
            }
            let (pn, pn_minus_1) = legendre(x);
            let slope = derivative(x, pn, pn_minus_1);
            (x, 2.0 / ((1.0 - x * x) * slope * slope))
        })
        .collect();

    rule.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
    rule.into_iter().unzip()
}
631
632fn worst_case_via_de<F, const D: usize>(
633 loss: &F,
634 params: &[f64],
635 prior: &Prior<D>,
636 inner_maxiter: usize,
637 inner_seed: u64,
638) -> Result<f64, AreaError>
639where
640 F: Fn(&[f64], [f64; D]) -> f64 + Sync,
641{
642 prior.validate()?;
643 let bounds_arr = prior.bounding_box();
644 let bounds_vec: Vec<(f64, f64)> = bounds_arr.iter().copied().collect();
645
646 let neg_loss = |p_vec: &Array1<f64>| -> f64 {
648 let mut p = [0.0_f64; D];
649 for i in 0..D {
650 p[i] = p_vec[i];
651 }
652 -loss(params, p)
653 };
654
655 let cfg = DEConfigBuilder::new()
656 .maxiter(inner_maxiter.max(5))
657 .popsize(8)
658 .seed(inner_seed)
659 .build()
660 .map_err(|e| AreaError::InnerSearchFailed(format!("{e}")))?;
661
662 let report = differential_evolution(&neg_loss, &bounds_vec, cfg)
663 .map_err(|e| AreaError::InnerSearchFailed(format!("{e}")))?;
664
665 Ok(-report.fun)
666}
667
668fn standard_normal_cdf(x: f64) -> f64 {
669 0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
671}
672
// The published coefficients carry more digits than f64 can represent; they are kept
// verbatim from the reference, hence the lint allowance.
#[allow(clippy::excessive_precision)]
/// Inverse standard normal CDF, `Phi^{-1}(u)`, via Acklam's rational approximation.
///
/// `u` is first clamped to `[1e-12, 1 - 1e-12]`, so the result is finite for any
/// non-NaN input. The central region `[0.02425, 0.97575]` uses a rational function in
/// `(u - 0.5)^2`. The tails use one in `sqrt(-2 ln(tail mass))`, mirrored for the
/// upper tail.
fn inv_standard_normal(u: f64) -> f64 {
    const CENTRAL_NUM: [f64; 6] = [
        -3.969683028665376e+01,
        2.209460984245205e+02,
        -2.759285104469687e+02,
        1.383577518672690e+02,
        -3.066479806614716e+01,
        2.506628277459239e+00,
    ];
    const CENTRAL_DEN: [f64; 5] = [
        -5.447609879822406e+01,
        1.615858368580409e+02,
        -1.556989798598866e+02,
        6.680131188771972e+01,
        -1.328068155288572e+01,
    ];
    const TAIL_NUM: [f64; 6] = [
        -7.784894002430293e-03,
        -3.223964580411365e-01,
        -2.400758277161838e+00,
        -2.549732539343734e+00,
        4.374664141464968e+00,
        2.938163982698783e+00,
    ];
    const TAIL_DEN: [f64; 4] = [
        7.784695709041462e-03,
        3.224671290700398e-01,
        2.445134137142996e+00,
        3.754408661907416e+00,
    ];
    const P_LOW: f64 = 0.02425;

    let p = u.clamp(1e-12, 1.0 - 1e-12);

    // Lower-tail branch, evaluated at the probability mass left in the tail.
    let tail = |mass: f64| -> f64 {
        let q = (-2.0 * mass.ln()).sqrt();
        let num = ((((TAIL_NUM[0] * q + TAIL_NUM[1]) * q + TAIL_NUM[2]) * q + TAIL_NUM[3]) * q
            + TAIL_NUM[4])
            * q
            + TAIL_NUM[5];
        let den = (((TAIL_DEN[0] * q + TAIL_DEN[1]) * q + TAIL_DEN[2]) * q + TAIL_DEN[3]) * q + 1.0;
        num / den
    };

    if p < P_LOW {
        tail(p)
    } else if p > 1.0 - P_LOW {
        // Symmetry: Phi^{-1}(p) = -Phi^{-1}(1 - p).
        -tail(1.0 - p)
    } else {
        let q = p - 0.5;
        let r = q * q;
        let num = (((((CENTRAL_NUM[0] * r + CENTRAL_NUM[1]) * r + CENTRAL_NUM[2]) * r
            + CENTRAL_NUM[3])
            * r
            + CENTRAL_NUM[4])
            * r
            + CENTRAL_NUM[5])
            * q;
        let den = ((((CENTRAL_DEN[0] * r + CENTRAL_DEN[1]) * r + CENTRAL_DEN[2]) * r
            + CENTRAL_DEN[3])
            * r
            + CENTRAL_DEN[4])
            * r
            + 1.0;
        num / den
    }
}
730
/// Error function via Abramowitz & Stegun formula 7.1.26.
///
/// The absolute error of this approximation is at most about `1.5e-7`; relative
/// accuracy deteriorates as `erf(x)` approaches 0. Odd symmetry is enforced by
/// evaluating at `|x|` and restoring the sign.
fn erf(x: f64) -> f64 {
    // Coefficients a1..a5 of the A&S polynomial, lowest order first.
    const A: [f64; 5] = [
        0.254829592,
        -0.284496736,
        1.421413741,
        -1.453152027,
        1.061405429,
    ];
    const P: f64 = 0.3275911;

    let sign = x.signum();
    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);
    // Horner evaluation of a1 + a2 t + a3 t^2 + a4 t^3 + a5 t^4.
    let poly = A.iter().rev().fold(0.0, |acc, &coeff| acc * t + coeff);
    let magnitude = 1.0 - poly * t * (-ax * ax).exp();
    sign * magnitude
}
745
#[cfg(test)]
mod tests {
    use super::*;

    /// `U(0, 1)` prior shared by most of the one-dimensional checks.
    fn unit_interval() -> Prior<1> {
        Prior::Uniform {
            bounds: [(0.0, 1.0)],
        }
    }

    /// One-dimensional Sobol rule with seed 0.
    fn sobol_1d(num_points: usize) -> Quadrature<1> {
        Quadrature::Sobol {
            num_points,
            seed: 0,
        }
    }

    /// Expected-value scalarisation with a dummy one-element parameter vector.
    fn expectation<F, const D: usize>(loss: &F, prior: &Prior<D>, rule: &Quadrature<D>) -> f64
    where
        F: Fn(&[f64], [f64; D]) -> f64 + Sync,
    {
        evaluate_area_loss(loss, &[0.0], prior, rule, AreaScalarisation::ExpectedValue)
    }

    #[test]
    fn sobol_uniform_integrates_p_squared() {
        // E[X^2] = 1/3 for X ~ U(0, 1).
        let square = |_p: &[f64], pt: [f64; 1]| pt[0] * pt[0];
        let v = expectation(&square, &unit_interval(), &sobol_1d(1024));
        assert!((v - 1.0 / 3.0).abs() < 1e-2, "got {}", v);
    }

    #[test]
    fn lhs_uniform_2d_integrates_constant_to_constant() {
        let prior: Prior<2> = Prior::Uniform {
            bounds: [(0.0, 2.0), (-1.0, 3.0)],
        };
        let rule: Quadrature<2> = Quadrature::LatinHypercube {
            num_points: 256,
            seed: 7,
        };
        let constant = |_p: &[f64], _pt: [f64; 2]| 5.5;
        let v = expectation(&constant, &prior, &rule);
        assert!((v - 5.5).abs() < 1e-9, "got {}", v);
    }

    #[test]
    fn gauss_legendre_exactness_polynomial_degree_three() {
        // The 2-point rule is exact for cubics. The mean of 3x^3 - 2x^2 + x over
        // U(-1, 1) is -2/3.
        let prior: Prior<1> = Prior::Uniform {
            bounds: [(-1.0, 1.0)],
        };
        let rule: Quadrature<1> = Quadrature::GaussLegendre { points_per_axis: 2 };
        let cubic = |_p: &[f64], pt: [f64; 1]| 3.0 * pt[0].powi(3) - 2.0 * pt[0].powi(2) + pt[0];
        let v = expectation(&cubic, &prior, &rule);
        assert!((v - (-2.0 / 3.0)).abs() < 1e-9, "got {}", v);
    }

    #[test]
    fn worst_case_finds_known_max() {
        // -(x - 0.4)^2 peaks at 0, inside [0, 1].
        let bowl = |_p: &[f64], pt: [f64; 1]| -(pt[0] - 0.4).powi(2);
        let search = AreaScalarisation::WorstCase {
            inner_maxiter: 60,
            inner_seed: 1,
        };
        let v = evaluate_area_loss(&bowl, &[0.0], &unit_interval(), &sobol_1d(16), search);
        assert!(v > -1e-3, "expected ~0, got {}", v);
    }

    #[test]
    fn gaussian_prior_expected_value_matches_known_mean() {
        // E[X^2] = mu^2 + sigma^2 = 1 + 0.25 for X ~ N(1, 0.25); truncating at 5 sigma
        // changes this negligibly.
        let prior: Prior<1> = Prior::Gaussian {
            mean: [1.0],
            cov_diag: [0.25],
            truncation_sigmas: 5.0,
        };
        let square = |_p: &[f64], pt: [f64; 1]| pt[0] * pt[0];
        let v = expectation(&square, &prior, &sobol_1d(4096));
        assert!((v - 1.25).abs() < 5e-2, "got {}", v);
    }

    #[test]
    fn cvar_concentrates_on_tail() {
        let spike = |_p: &[f64], pt: [f64; 1]| if pt[0] > 0.9 { 100.0 } else { 1.0 };
        let prior = unit_interval();
        let rule = sobol_1d(1024);
        let mean = expectation(&spike, &prior, &rule);
        let cvar = evaluate_area_loss(
            &spike,
            &[0.0],
            &prior,
            &rule,
            AreaScalarisation::Cvar { alpha: 0.1 },
        );
        assert!(
            cvar > mean * 5.0,
            "cvar {} should be >> mean {}",
            cvar,
            mean
        );
    }

    #[test]
    fn rejects_zero_quadrature_points() {
        let one = |_p: &[f64], _pt: [f64; 1]| 1.0;
        let outcome = try_evaluate_area_loss(
            &one,
            &[0.0],
            &unit_interval(),
            &sobol_1d(0),
            AreaScalarisation::ExpectedValue,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn rejects_degenerate_uniform_bounds() {
        let collapsed: Prior<1> = Prior::Uniform {
            bounds: [(1.0, 1.0)],
        };
        assert!(collapsed.validate().is_err());
    }

    #[test]
    fn gauss_legendre_1d_nodes_symmetric() {
        for n in 2..=6 {
            let (nodes, weights) = gauss_legendre_1d(n);
            assert_eq!(nodes.len(), n);
            assert_eq!(weights.len(), n);
            let total_w = weights.iter().sum::<f64>();
            assert!(
                (total_w - 2.0).abs() < 1e-10,
                "n={}: total_w={}",
                n,
                total_w
            );
            // Nodes are mirror images about 0; mirrored weights are equal.
            for i in 0..n / 2 {
                let j = n - 1 - i;
                assert!(
                    (nodes[i] + nodes[j]).abs() < 1e-10,
                    "n={}, i={}: nodes={:?}",
                    n,
                    i,
                    nodes
                );
                assert!(
                    (weights[i] - weights[j]).abs() < 1e-10,
                    "n={}, i={}: weights={:?}",
                    n,
                    i,
                    weights
                );
            }
        }
    }

    #[test]
    fn standard_normal_cdf_known_values() {
        let cases = [(0.0, 0.5, 1e-6), (1.0, 0.8413447, 1e-4), (-1.0, 0.1586553, 1e-4)];
        for (z, expected_p, tol) in cases {
            assert!((standard_normal_cdf(z) - expected_p).abs() < tol);
        }
    }

    #[test]
    fn inv_standard_normal_round_trip() {
        for p in [0.05_f64, 0.25, 0.5, 0.75, 0.95] {
            let z = inv_standard_normal(p);
            let p2 = standard_normal_cdf(z);
            assert!((p - p2).abs() < 1e-3, "p={}, z={}, p2={}", p, z, p2);
        }
    }
}