Skip to main content

math_audio_optimisation/
continuous_area.rs

1//! Continuous-prior loss integration for spatial / area-based optimization.
2//!
3//! Generic over dimension via const generics: `D=1` for a line of seats, `D=2`
4//! for a listening rectangle, `D=3` for a head-volume sweep, and so on.
5//!
6//! # Three building blocks
7//!
8//! - [`Prior`] — the probability distribution π(p) over positions p ∈ R^D.
9//!   Currently `Uniform` over an axis-aligned box and axis-aligned `Gaussian`
10//!   are first-class; `Custom` accepts an arbitrary density.
//! - [`Quadrature`] — how to discretise the integral into Q evaluation points.
//!   Supports a `Sobol`-named low-discrepancy QMC scheme (currently generated
//!   via `init_sobol::init_halton`), Latin-Hypercube, and Gauss–Legendre
//!   tensor-product. All schemes are deterministic; only Latin-Hypercube uses its seed.
14//! - [`AreaScalarisation`] — what to do with the Q losses: expected value,
15//!   worst-case (max), or CVaR (mean of the worst α-tail).
16//!
17//! # The high-level call
18//!
19//! [`evaluate_area_loss`] takes a base loss `L(params, p)`, a [`Prior`], a
20//! [`Quadrature`], and a [`AreaScalarisation`] and returns one scalar that an
21//! outer optimizer can minimise. The outer optimizer never sees the
22//! quadrature; it just sees a robust scalar objective.
23//!
24//! # Cost model
25//!
26//! Each outer fitness call costs Q base-loss evaluations for `ExpectedValue`
27//! and `CVaR`. `WorstCase` runs a small inner DE search per outer call —
28//! callers should be aware this is more expensive (typically 10×–50× a single
29//! base-loss eval).
30//!
31//! ```rust,no_run
32//! use math_audio_optimisation::continuous_area::{
33//!     AreaScalarisation, Prior, Quadrature, evaluate_area_loss,
34//! };
35//!
36//! // Minimise expected value of (params[0] - p)^2 with p ~ Uniform([-1,1])
37//! let prior: Prior<1> = Prior::Uniform { bounds: [(-1.0, 1.0)] };
38//! let quadrature: Quadrature<1> = Quadrature::Sobol {
39//!     num_points: 128,
40//!     seed: 0,
41//! };
42//! let loss = |params: &[f64], p: [f64; 1]| (params[0] - p[0]).powi(2);
43//! let value = evaluate_area_loss(&loss, &[0.0], &prior, &quadrature, AreaScalarisation::ExpectedValue);
44//! assert!((value - 1.0 / 3.0).abs() < 0.05);
45//! ```
46
47use rand::SeedableRng;
48use rand::rngs::StdRng;
49
50use crate::differential_evolution;
51use crate::{DEConfigBuilder, init_latin_hypercube::init_latin_hypercube};
52use ndarray::{Array1, Array2};
53
/// Probability density / sampling region over R^D.
///
/// All variants are checked by [`Prior::validate`] before any quadrature
/// points are drawn. The sample space described here is *the support of the
/// prior*: quadrature points are placed inside [`Prior::bounding_box`] and,
/// for `Gaussian`, pushed through a truncated inverse-CDF.
#[derive(Clone)]
pub enum Prior<const D: usize> {
    /// Uniform density on the axis-aligned box `bounds[i] = (lo, hi)`.
    Uniform {
        /// Per-axis `(lower, upper)` bounds. Both must be finite and lower
        /// must be strictly less than upper.
        bounds: [(f64, f64); D],
    },
    /// Axis-aligned Gaussian. `cov_diag[i]` holds σ² along axis i.
    ///
    /// The distribution is truncated to `mean[i] ± k·σᵢ` on every axis, with
    /// `k = truncation_sigmas`. Sampling-based quadratures (`Sobol`,
    /// `LatinHypercube`) map unit samples through the inverse CDF restricted
    /// to that window, so no sample falls outside it (samples are remapped,
    /// not clamped). `GaussLegendre` is rejected for this variant.
    Gaussian {
        /// Per-axis means.
        mean: [f64; D],
        /// Per-axis variances (must be finite and > 0).
        cov_diag: [f64; D],
        /// Truncation half-width `k`, in standard deviations (must be finite
        /// and > 0). There is no built-in default; `4.0` keeps > 99.99 % of
        /// the mass on each axis while avoiding runaway tails.
        truncation_sigmas: f64,
    },
    /// Arbitrary density specified pointwise over a caller-supplied
    /// axis-aligned bounding box.
    ///
    /// Quadrature points are placed on `bounds` exactly as for `Uniform`
    /// (with `Sobol`, `LatinHypercube`, or `GaussLegendre`). The closure then
    /// importance-weights them: each weight is multiplied by the density and
    /// the weights are renormalised to sum to 1. The density therefore does
    /// not need to be normalised.
    Custom {
        /// Bounding box the quadrature points are drawn from.
        bounds: [(f64, f64); D],
        /// Density evaluated at each quadrature point. Should be non-negative;
        /// negative and NaN values are treated as zero by the quadrature
        /// builders.
        density: std::sync::Arc<dyn Fn([f64; D]) -> f64 + Send + Sync>,
    },
}
90
91impl<const D: usize> std::fmt::Debug for Prior<D> {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        match self {
94            Prior::Uniform { bounds } => f
95                .debug_struct("Prior::Uniform")
96                .field("bounds", bounds)
97                .finish(),
98            Prior::Gaussian {
99                mean,
100                cov_diag,
101                truncation_sigmas,
102            } => f
103                .debug_struct("Prior::Gaussian")
104                .field("mean", mean)
105                .field("cov_diag", cov_diag)
106                .field("truncation_sigmas", truncation_sigmas)
107                .finish(),
108            Prior::Custom { bounds, .. } => f
109                .debug_struct("Prior::Custom")
110                .field("bounds", bounds)
111                .field("density", &"<closure>")
112                .finish(),
113        }
114    }
115}
116
117impl<const D: usize> Prior<D> {
118    /// Reject malformed priors before any quadrature samples are drawn.
119    pub fn validate(&self) -> Result<(), AreaError> {
120        match self {
121            Prior::Uniform { bounds } | Prior::Custom { bounds, .. } => {
122                for (i, (lo, hi)) in bounds.iter().enumerate() {
123                    if !(lo.is_finite() && hi.is_finite()) || hi <= lo {
124                        return Err(AreaError::InvalidPrior(format!(
125                            "axis {} bounds [{}, {}] are degenerate",
126                            i, lo, hi
127                        )));
128                    }
129                }
130                Ok(())
131            }
132            Prior::Gaussian {
133                cov_diag,
134                truncation_sigmas,
135                ..
136            } => {
137                if !truncation_sigmas.is_finite() || *truncation_sigmas <= 0.0 {
138                    return Err(AreaError::InvalidPrior(format!(
139                        "Gaussian truncation_sigmas must be > 0, got {}",
140                        truncation_sigmas
141                    )));
142                }
143                for (i, &v) in cov_diag.iter().enumerate() {
144                    if !v.is_finite() || v <= 0.0 {
145                        return Err(AreaError::InvalidPrior(format!(
146                            "Gaussian variance on axis {} must be > 0, got {}",
147                            i, v
148                        )));
149                    }
150                }
151                Ok(())
152            }
153        }
154    }
155
156    /// Axis-aligned bounding box used as the search space for [`Quadrature::WorstCaseSearch`]
157    /// and the integration domain for tensor-product / sampled quadratures.
158    pub fn bounding_box(&self) -> [(f64, f64); D] {
159        match self {
160            Prior::Uniform { bounds } | Prior::Custom { bounds, .. } => *bounds,
161            Prior::Gaussian {
162                mean,
163                cov_diag,
164                truncation_sigmas,
165            } => {
166                let mut out = [(0.0_f64, 0.0_f64); D];
167                for i in 0..D {
168                    let sigma = cov_diag[i].sqrt();
169                    out[i] = (
170                        mean[i] - truncation_sigmas * sigma,
171                        mean[i] + truncation_sigmas * sigma,
172                    );
173                }
174                out
175            }
176        }
177    }
178}
179
/// Quadrature scheme for discretising the prior integral into Q sample points.
#[derive(Debug, Clone)]
pub enum Quadrature<const D: usize> {
    /// Deterministic low-discrepancy (quasi-Monte-Carlo) points on `[0,1]^D`,
    /// transformed to the prior support.
    ///
    /// Despite the variant name, the unit points are currently produced by
    /// `init_sobol::init_halton` rather than a Sobol routine, and no
    /// scrambling is applied. Low-discrepancy sequences typically converge
    /// like `O(log(N)^D / N)` on smooth integrands, versus Monte Carlo's
    /// `O(1/sqrt(N))`.
    Sobol {
        /// Number of quadrature points (must be > 0).
        num_points: usize,
        /// Currently ignored: the point set depends only on `num_points`.
        seed: u64,
    },
    /// Latin-Hypercube sampling on `[0,1]^D`, transformed to the prior support.
    LatinHypercube {
        /// Number of quadrature points (must be > 0).
        num_points: usize,
        /// Seed for the PRNG that drives the sampler, for reproducibility.
        seed: u64,
    },
    /// Gauss–Legendre tensor product over the prior's axis-aligned box.
    ///
    /// The total point count is `points_per_axis^D`. For a `Uniform` prior the
    /// rule is exact on polynomials up to degree `2*points_per_axis - 1`
    /// along each axis. `Custom` priors are supported by multiplying each
    /// node's weight by the density and renormalising. `Gaussian` priors are
    /// rejected, because they would need Gauss–Hermite instead.
    GaussLegendre {
        /// Nodes per axis (must be > 0). Total points = `points_per_axis.pow(D as u32)`.
        points_per_axis: usize,
    },
}
209
/// Rule for collapsing the per-point losses into the single scalar handed
/// back to the outer optimiser.
#[derive(Debug, Clone, Copy)]
pub enum AreaScalarisation {
    /// Prior-weighted mean `∫ L(x, p) π(p) dp`, approximated by the
    /// quadrature's weighted sum. This is the usual "expected loss over the
    /// listening area".
    ExpectedValue,
    /// Minimax objective `max_{p ∈ support(π)} L(x, p)`.
    ///
    /// Computed by an inner differential-evolution search over the prior's
    /// bounding box. The density shape and the `Quadrature` argument are both
    /// ignored. The value returned is the largest loss the search found, so it
    /// can fall short of the true supremum.
    WorstCase {
        /// Inner-search iteration budget; values below 5 are raised to 5.
        /// Around 50 is usually enough for D ≤ 3.
        inner_maxiter: usize,
        /// Seed for the inner search, for reproducibility.
        inner_seed: u64,
    },
    /// Conditional Value-at-Risk at level α: the weight-averaged loss over the
    /// worst α-fraction of prior mass. `alpha = 0.1` averages the worst 10 %;
    /// `alpha = 1.0` reduces to [`AreaScalarisation::ExpectedValue`].
    Cvar {
        /// Tail fraction; must lie in (0, 1].
        alpha: f64,
    },
}
233
234/// Errors raised by continuous-area evaluation.
235#[derive(Debug, thiserror::Error)]
236pub enum AreaError {
237    /// The prior was malformed (degenerate bounds, non-positive variance, …).
238    #[error("invalid prior: {0}")]
239    InvalidPrior(String),
240    /// The quadrature was malformed (zero points, non-finite Sobol seed, …).
241    #[error("invalid quadrature: {0}")]
242    InvalidQuadrature(String),
243    /// Mixing `GaussLegendre` with a non-bounded prior, or similar invariants.
244    #[error("incompatible prior/quadrature: {0}")]
245    IncompatiblePriorQuadrature(String),
246    /// Inner DE search for `WorstCase` failed.
247    #[error("inner worst-case search failed: {0}")]
248    InnerSearchFailed(String),
249}
250
251/// Generate the quadrature points and corresponding integration weights.
252///
253/// For sampling-based quadratures (Sobol, Latin-Hypercube), the weights also
254/// fold in the prior density (importance weighting), so caller code can use a
255/// simple weighted sum. For Gauss–Legendre, the weights are the standard
256/// tensor-product Gauss–Legendre weights scaled to the bounding box.
257///
258/// Returns `(points, weights)` with `points.len() == weights.len()` and
259/// `weights.sum() == 1.0` (after normalisation against the prior's total mass
260/// over the sampled domain).
261pub fn build_quadrature_points<const D: usize>(
262    prior: &Prior<D>,
263    quadrature: &Quadrature<D>,
264) -> Result<(Vec<[f64; D]>, Vec<f64>), AreaError> {
265    prior.validate()?;
266    let bounds = prior.bounding_box();
267
268    match quadrature {
269        Quadrature::Sobol { num_points, seed } => {
270            if *num_points == 0 {
271                return Err(AreaError::InvalidQuadrature(
272                    "Sobol num_points must be > 0".into(),
273                ));
274            }
275            let raw = sobol_unit(*num_points, *seed);
276            transform_unit_samples(&raw, prior, &bounds)
277        }
278        Quadrature::LatinHypercube { num_points, seed } => {
279            if *num_points == 0 {
280                return Err(AreaError::InvalidQuadrature(
281                    "LatinHypercube num_points must be > 0".into(),
282                ));
283            }
284            let raw = latin_hypercube_unit::<D>(*num_points, *seed);
285            transform_unit_samples(&raw, prior, &bounds)
286        }
287        Quadrature::GaussLegendre { points_per_axis } => {
288            if *points_per_axis == 0 {
289                return Err(AreaError::InvalidQuadrature(
290                    "GaussLegendre points_per_axis must be > 0".into(),
291                ));
292            }
293            match prior {
294                Prior::Uniform { bounds } => Ok(gauss_legendre_tensor(*points_per_axis, bounds)),
295                Prior::Custom { bounds, density } => {
296                    // Importance-weighted GL: multiply each weight by density and renormalise.
297                    let (pts, mut weights) = gauss_legendre_tensor(*points_per_axis, bounds);
298                    for (p, w) in pts.iter().zip(weights.iter_mut()) {
299                        *w *= density(*p).max(0.0);
300                    }
301                    let total: f64 = weights.iter().sum();
302                    if total <= 0.0 {
303                        return Err(AreaError::InvalidPrior(
304                            "Custom density evaluated to zero on every quadrature node".into(),
305                        ));
306                    }
307                    for w in weights.iter_mut() {
308                        *w /= total;
309                    }
310                    Ok((pts, weights))
311                }
312                Prior::Gaussian { .. } => Err(AreaError::IncompatiblePriorQuadrature(
313                    "GaussLegendre on a Gaussian prior would require Gauss–Hermite; \
314                     use Sobol or LatinHypercube for unbounded priors"
315                        .into(),
316                )),
317            }
318        }
319    }
320}
321
322/// Evaluate a continuous-area loss.
323///
324/// `loss(params, p)` is the per-point loss; `params` is passed through opaquely
325/// — the outer optimizer owns its meaning. Returns one scalar suitable for
326/// minimisation.
327pub fn evaluate_area_loss<F, const D: usize>(
328    loss: &F,
329    params: &[f64],
330    prior: &Prior<D>,
331    quadrature: &Quadrature<D>,
332    scalarisation: AreaScalarisation,
333) -> f64
334where
335    F: Fn(&[f64], [f64; D]) -> f64 + Sync,
336{
337    try_evaluate_area_loss(loss, params, prior, quadrature, scalarisation)
338        .unwrap_or_else(|e| panic!("evaluate_area_loss: {e}"))
339}
340
341/// Fallible version of [`evaluate_area_loss`]: returns errors instead of panicking.
342pub fn try_evaluate_area_loss<F, const D: usize>(
343    loss: &F,
344    params: &[f64],
345    prior: &Prior<D>,
346    quadrature: &Quadrature<D>,
347    scalarisation: AreaScalarisation,
348) -> Result<f64, AreaError>
349where
350    F: Fn(&[f64], [f64; D]) -> f64 + Sync,
351{
352    match scalarisation {
353        AreaScalarisation::WorstCase {
354            inner_maxiter,
355            inner_seed,
356        } => worst_case_via_de(loss, params, prior, inner_maxiter, inner_seed),
357        AreaScalarisation::ExpectedValue => {
358            let (points, weights) = build_quadrature_points(prior, quadrature)?;
359            let mut acc = 0.0;
360            for (p, w) in points.iter().zip(weights.iter()) {
361                acc += w * loss(params, *p);
362            }
363            Ok(acc)
364        }
365        AreaScalarisation::Cvar { alpha } => {
366            if !(0.0..=1.0).contains(&alpha) || alpha <= 0.0 {
367                return Err(AreaError::InvalidQuadrature(format!(
368                    "CVaR alpha must be in (0, 1], got {}",
369                    alpha
370                )));
371            }
372            let (points, weights) = build_quadrature_points(prior, quadrature)?;
373            // Compute losses and pair with importance weights.
374            let mut wl: Vec<(f64, f64)> = points
375                .iter()
376                .zip(weights.iter())
377                .map(|(p, &w)| (loss(params, *p), w))
378                .collect();
379            // Worst losses first.
380            wl.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
381            // Walk down the sorted list, accumulating mass until α is reached.
382            let mut acc_loss = 0.0;
383            let mut acc_mass = 0.0;
384            for (l, w) in &wl {
385                let take = (alpha - acc_mass).min(*w);
386                if take <= 0.0 {
387                    break;
388                }
389                acc_loss += take * l;
390                acc_mass += take;
391                if acc_mass >= alpha {
392                    break;
393                }
394            }
395            if acc_mass <= 0.0 {
396                return Err(AreaError::InvalidQuadrature(
397                    "CVaR encountered zero total importance weight".into(),
398                ));
399            }
400            Ok(acc_loss / acc_mass)
401        }
402    }
403}
404
405// ============================================================================
406// Internals
407// ============================================================================
408
409fn sobol_unit<const D: usize>(num_points: usize, _seed: u64) -> Vec<[f64; D]> {
410    // Reuse the existing `init_sobol` implementation but re-package per-point.
411    // `init_sobol` returns `Vec<Vec<f64>>` over user-supplied bounds; we
412    // request `[0,1]^D` and copy into a stack array per point.
413    let unit_bounds: Vec<(f64, f64)> = (0..D).map(|_| (0.0, 1.0)).collect();
414    let raw = crate::init_sobol::init_halton(D, num_points, &unit_bounds);
415    raw.into_iter()
416        .map(|v| {
417            let mut out = [0.0_f64; D];
418            for (i, x) in v.into_iter().enumerate().take(D) {
419                out[i] = x;
420            }
421            out
422        })
423        .collect()
424}
425
426fn latin_hypercube_unit<const D: usize>(num_points: usize, seed: u64) -> Vec<[f64; D]> {
427    let lower = Array1::<f64>::zeros(D);
428    let upper = Array1::<f64>::ones(D);
429    let is_free = vec![true; D];
430    let mut rng = StdRng::seed_from_u64(seed);
431    let m: Array2<f64> = init_latin_hypercube(D, num_points, &lower, &upper, &is_free, &mut rng);
432    (0..num_points)
433        .map(|row| {
434            let mut out = [0.0_f64; D];
435            for col in 0..D {
436                out[col] = m[(row, col)];
437            }
438            out
439        })
440        .collect()
441}
442
443fn transform_unit_samples<const D: usize>(
444    raw: &[[f64; D]],
445    prior: &Prior<D>,
446    bounds: &[(f64, f64); D],
447) -> Result<(Vec<[f64; D]>, Vec<f64>), AreaError> {
448    let n = raw.len();
449    let uniform_weight = 1.0 / n as f64;
450
451    match prior {
452        Prior::Uniform { .. } => {
453            let pts: Vec<[f64; D]> = raw
454                .iter()
455                .map(|u| {
456                    let mut out = [0.0_f64; D];
457                    for i in 0..D {
458                        out[i] = bounds[i].0 + u[i] * (bounds[i].1 - bounds[i].0);
459                    }
460                    out
461                })
462                .collect();
463            Ok((pts, vec![uniform_weight; n]))
464        }
465        Prior::Gaussian { mean, cov_diag, .. } => {
466            // Inverse-CDF transform: u → Φ⁻¹(u). Then scale and shift.
467            // The bounding-box clamp from `bounding_box()` already enforces
468            // the truncation; remap u to [u_lo, u_hi] before inverse-CDF.
469            let mut pts: Vec<[f64; D]> = Vec::with_capacity(n);
470            for u in raw {
471                let mut out = [0.0_f64; D];
472                for i in 0..D {
473                    let sigma = cov_diag[i].sqrt();
474                    // Truncation bounds in standardised units:
475                    let z_lo = (bounds[i].0 - mean[i]) / sigma;
476                    let z_hi = (bounds[i].1 - mean[i]) / sigma;
477                    let p_lo = standard_normal_cdf(z_lo);
478                    let p_hi = standard_normal_cdf(z_hi);
479                    let u_remap = p_lo + u[i] * (p_hi - p_lo);
480                    let z = inv_standard_normal(u_remap);
481                    out[i] = mean[i] + sigma * z;
482                }
483                pts.push(out);
484            }
485            Ok((pts, vec![uniform_weight; n]))
486        }
487        Prior::Custom { density, .. } => {
488            // Uniform sampling on bounding box, importance-weighted by density.
489            let pts: Vec<[f64; D]> = raw
490                .iter()
491                .map(|u| {
492                    let mut out = [0.0_f64; D];
493                    for i in 0..D {
494                        out[i] = bounds[i].0 + u[i] * (bounds[i].1 - bounds[i].0);
495                    }
496                    out
497                })
498                .collect();
499            let mut weights: Vec<f64> = pts.iter().map(|p| density(*p).max(0.0)).collect();
500            let total: f64 = weights.iter().sum();
501            if total <= 0.0 {
502                return Err(AreaError::InvalidPrior(
503                    "Custom density evaluated to zero on every sampled point".into(),
504                ));
505            }
506            for w in weights.iter_mut() {
507                *w /= total;
508            }
509            Ok((pts, weights))
510        }
511    }
512}
513
/// Tensor-product Gauss–Legendre nodes and probability weights on a box.
///
/// Each axis gets `points_per_axis` nodes, mapped affinely from `[-1, 1]` onto
/// `bounds[i]`. A joint node's weight is the product of its per-axis weights.
/// The weights are then divided by their sum (which equals the box volume),
/// so they sum to 1. The weighted sum of `f` over the nodes therefore
/// approximates the mean of `f` under the uniform distribution on the box.
/// Nodes are enumerated with axis 0 varying fastest.
///
/// # Panics
///
/// Panics if `points_per_axis^D` overflows `usize`. This is checked before any
/// nodes are computed.
fn gauss_legendre_tensor<const D: usize>(
    points_per_axis: usize,
    bounds: &[(f64, f64); D],
) -> (Vec<[f64; D]>, Vec<f64>) {
    // Check the grid size first so an absurd request fails fast instead of
    // computing a huge 1-D rule, then wrapping (release) or panicking (debug).
    let total = u32::try_from(D)
        .ok()
        .and_then(|exp| points_per_axis.checked_pow(exp))
        .expect("Gauss–Legendre grid size points_per_axis^D overflows usize");

    let (nodes_unit, weights_unit) = gauss_legendre_1d(points_per_axis);
    // Per-axis rescaling: x = mid + half·ξ, w = half·w_ξ.
    let axes: [(Vec<f64>, Vec<f64>); D] = std::array::from_fn(|i| {
        let (lo, hi) = bounds[i];
        let mid = 0.5 * (hi + lo);
        let half = 0.5 * (hi - lo);
        let nodes = nodes_unit.iter().map(|&xi| mid + half * xi).collect();
        let weights = weights_unit.iter().map(|&wi| half * wi).collect();
        (nodes, weights)
    });

    let mut pts: Vec<[f64; D]> = Vec::with_capacity(total);
    let mut wts: Vec<f64> = Vec::with_capacity(total);
    for idx in 0..total {
        let mut pt = [0.0_f64; D];
        let mut w = 1.0_f64;
        // Decode idx as a base-`points_per_axis` number, one digit per axis.
        let mut k = idx;
        for (i, (nodes, weights)) in axes.iter().enumerate() {
            let ki = k % points_per_axis;
            k /= points_per_axis;
            pt[i] = nodes[ki];
            w *= weights[ki];
        }
        pts.push(pt);
        wts.push(w);
    }

    // Normalise so the weights sum to 1 (probability weights under a uniform prior).
    let total_w: f64 = wts.iter().sum();
    if total_w > 0.0 {
        for w in wts.iter_mut() {
            *w /= total_w;
        }
    }

    (pts, wts)
}

/// Gauss–Legendre nodes and weights on `[-1, 1]`, with nodes in ascending order.
///
/// `n = 0` yields empty vectors and `n = 1` the midpoint rule. For `n >= 2`
/// every node is found by Newton iteration on the Legendre polynomial `P_n`
/// (evaluated with the three-term recurrence), starting from
/// `cos(π(i + 3/4)/(n + 1/2))`. Weights use `w_i = 2 / ((1 - x_i²) P_n'(x_i)²)`.
/// There is no tabulated fast path. An `n`-point rule integrates polynomials
/// up to degree `2n - 1` exactly.
fn gauss_legendre_1d(n: usize) -> (Vec<f64>, Vec<f64>) {
    /// `(P_n(x), P_{n-1}(x))` via `(k+1) P_{k+1} = (2k+1) x P_k - k P_{k-1}`; needs `n >= 1`.
    fn legendre_pair(n: usize, x: f64) -> (f64, f64) {
        let mut p_prev = 1.0_f64; // P_{k-1}
        let mut p_curr = x; // P_k
        for k in 1..n {
            let kf = k as f64;
            let p_next = ((2.0 * kf + 1.0) * x * p_curr - kf * p_prev) / (kf + 1.0);
            p_prev = p_curr;
            p_curr = p_next;
        }
        (p_curr, p_prev)
    }

    /// `P_n'(x) = n (x P_n - P_{n-1}) / (x² - 1)`, valid for `|x| != 1`.
    fn legendre_derivative(n: usize, x: f64, p_n: f64, p_nm1: f64) -> f64 {
        n as f64 * (x * p_n - p_nm1) / (x * x - 1.0)
    }

    match n {
        0 => return (Vec::new(), Vec::new()),
        1 => return (vec![0.0], vec![2.0]),
        _ => {}
    }

    let mut rule: Vec<(f64, f64)> = (0..n)
        .map(|i| {
            let mut x = (std::f64::consts::PI * (i as f64 + 0.75) / (n as f64 + 0.5)).cos();
            for _ in 0..50 {
                let (p_n, p_nm1) = legendre_pair(n, x);
                let dx = p_n / legendre_derivative(n, x, p_n, p_nm1);
                x -= dx;
                if dx.abs() < 1e-15 {
                    break;
                }
            }
            // Re-evaluate at the converged node for the weight formula.
            let (p_n, p_nm1) = legendre_pair(n, x);
            let dp_n = legendre_derivative(n, x, p_n, p_nm1);
            (x, 2.0 / ((1.0 - x * x) * dp_n * dp_n))
        })
        .collect();

    // Sort nodes ascending so the output is canonical.
    rule.sort_by(|a, b| a.0.total_cmp(&b.0));
    rule.into_iter().unzip()
}
631
632fn worst_case_via_de<F, const D: usize>(
633    loss: &F,
634    params: &[f64],
635    prior: &Prior<D>,
636    inner_maxiter: usize,
637    inner_seed: u64,
638) -> Result<f64, AreaError>
639where
640    F: Fn(&[f64], [f64; D]) -> f64 + Sync,
641{
642    prior.validate()?;
643    let bounds_arr = prior.bounding_box();
644    let bounds_vec: Vec<(f64, f64)> = bounds_arr.iter().copied().collect();
645
646    // Negate the loss so DE (minimiser) finds the maximiser.
647    let neg_loss = |p_vec: &Array1<f64>| -> f64 {
648        let mut p = [0.0_f64; D];
649        for i in 0..D {
650            p[i] = p_vec[i];
651        }
652        -loss(params, p)
653    };
654
655    let cfg = DEConfigBuilder::new()
656        .maxiter(inner_maxiter.max(5))
657        .popsize(8)
658        .seed(inner_seed)
659        .build()
660        .map_err(|e| AreaError::InnerSearchFailed(format!("{e}")))?;
661
662    let report = differential_evolution(&neg_loss, &bounds_vec, cfg)
663        .map_err(|e| AreaError::InnerSearchFailed(format!("{e}")))?;
664
665    Ok(-report.fun)
666}
667
/// Standard normal CDF, `Φ(x) = (1 + erf(x / √2)) / 2`.
///
/// Accuracy is limited by the `erf` approximation (absolute error ≲ 1.5e-7).
fn standard_normal_cdf(x: f64) -> f64 {
    0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
}

/// Inverse standard normal CDF, `Φ⁻¹(u)`.
///
/// Uses Peter Acklam's rational approximation (relative error ≈ 1.15e-9):
/// - a central rational function in `q = u - 1/2` for
///   `P_LOW <= u <= 1 - P_LOW`;
/// - a tail rational function in `q = sqrt(-2 ln p)` outside that range,
///   where `p` is the smaller tail probability.
///
/// `u` is first clamped to `[1e-12, 1 - 1e-12]`, so the result stays finite
/// (|z| ≲ 7.03). A NaN input yields NaN.
// Coefficients are quoted verbatim from Acklam's published table.
#[allow(clippy::excessive_precision)]
fn inv_standard_normal(u: f64) -> f64 {
    const CENTRAL_NUM: [f64; 6] = [
        -3.969683028665376e+01,
        2.209460984245205e+02,
        -2.759285104469687e+02,
        1.383577518672690e+02,
        -3.066479806614716e+01,
        2.506628277459239e+00,
    ];
    const CENTRAL_DEN: [f64; 5] = [
        -5.447609879822406e+01,
        1.615858368580409e+02,
        -1.556989798598866e+02,
        6.680131188771972e+01,
        -1.328068155288572e+01,
    ];
    const TAIL_NUM: [f64; 6] = [
        -7.784894002430293e-03,
        -3.223964580411365e-01,
        -2.400758277161838e+00,
        -2.549732539343734e+00,
        4.374664141464968e+00,
        2.938163982698783e+00,
    ];
    const TAIL_DEN: [f64; 4] = [
        7.784695709041462e-03,
        3.224671290700398e-01,
        2.445134137142996e+00,
        3.754408661907416e+00,
    ];
    const P_LOW: f64 = 0.02425;

    let u = u.clamp(1e-12, 1.0 - 1e-12);

    // Lower-tail approximation in terms of the tail probability `p`;
    // the upper tail follows by symmetry.
    let tail = |p: f64| -> f64 {
        let (c, d) = (&TAIL_NUM, &TAIL_DEN);
        let q = (-2.0 * p.ln()).sqrt();
        let num = ((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4]) * q + c[5];
        let den = (((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1.0;
        num / den
    };

    if u < P_LOW {
        tail(u)
    } else if u <= 1.0 - P_LOW {
        let (a, b) = (&CENTRAL_NUM, &CENTRAL_DEN);
        let q = u - 0.5;
        let r = q * q;
        (((((a[0] * r + a[1]) * r + a[2]) * r + a[3]) * r + a[4]) * r + a[5]) * q
            / (((((b[0] * r + b[1]) * r + b[2]) * r + b[3]) * r + b[4]) * r + 1.0)
    } else {
        -tail(1.0 - u)
    }
}

/// Error function via Abramowitz & Stegun formula 7.1.26 (absolute error
/// ≤ 1.5e-7), extended to negative `x` by odd symmetry.
fn erf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    const A: [f64; 5] = [
        0.254829592,
        -0.284496736,
        1.421413741,
        -1.453152027,
        1.061405429,
    ];
    let sign = x.signum();
    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);
    let poly = ((((A[4] * t + A[3]) * t + A[2]) * t + A[1]) * t + A[0]) * t;
    sign * (1.0 - poly * (-ax * ax).exp())
}
745
// Unit tests for the continuous-prior area-loss machinery: quadrature accuracy,
// prior transforms, the three scalarisations, input validation, and the numeric
// helpers (1-D Gauss–Legendre rule, standard normal CDF and its inverse).
#[cfg(test)]
mod tests {
    use super::*;

    /// Uniform prior over the unit interval [0, 1], shared by several 1-D tests.
    fn unit_interval() -> Prior<1> {
        Prior::Uniform {
            bounds: [(0.0, 1.0)],
        }
    }

    #[test]
    fn sobol_uniform_integrates_p_squared() {
        // Exact answer: ∫_0^1 p^2 dp = 1/3 (the interval has unit length, so the
        // probability-weighted mean equals the plain integral).
        let squared = |_params: &[f64], pos: [f64; 1]| pos[0] * pos[0];
        let sobol: Quadrature<1> = Quadrature::Sobol {
            num_points: 1024,
            seed: 0,
        };
        let estimate = evaluate_area_loss(
            &squared,
            &[0.0],
            &unit_interval(),
            &sobol,
            AreaScalarisation::ExpectedValue,
        );
        assert!((estimate - 1.0 / 3.0).abs() < 1e-2, "got {}", estimate);
    }

    #[test]
    fn lhs_uniform_2d_integrates_constant_to_constant() {
        // The expectation of a constant is that constant, whatever the sample layout.
        const LEVEL: f64 = 5.5;
        let flat = |_params: &[f64], _pos: [f64; 2]| LEVEL;
        let rectangle: Prior<2> = Prior::Uniform {
            bounds: [(0.0, 2.0), (-1.0, 3.0)],
        };
        let lhs: Quadrature<2> = Quadrature::LatinHypercube {
            num_points: 256,
            seed: 7,
        };
        let estimate = evaluate_area_loss(
            &flat,
            &[0.0],
            &rectangle,
            &lhs,
            AreaScalarisation::ExpectedValue,
        );
        assert!((estimate - LEVEL).abs() < 1e-9, "got {}", estimate);
    }

    #[test]
    fn gauss_legendre_exactness_polynomial_degree_three() {
        // A 2-node Gauss–Legendre rule integrates polynomials of degree ≤ 3 exactly.
        // ∫_{-1}^{1} (3p^3 - 2p^2 + p) dp = -4/3: the odd terms cancel, leaving -2·(2/3).
        // The prior is a probability measure, so divide by the interval length 2 → -2/3.
        let cubic =
            |_params: &[f64], pos: [f64; 1]| 3.0 * pos[0].powi(3) - 2.0 * pos[0].powi(2) + pos[0];
        let symmetric_interval: Prior<1> = Prior::Uniform {
            bounds: [(-1.0, 1.0)],
        };
        let two_point: Quadrature<1> = Quadrature::GaussLegendre { points_per_axis: 2 };
        let estimate = evaluate_area_loss(
            &cubic,
            &[0.0],
            &symmetric_interval,
            &two_point,
            AreaScalarisation::ExpectedValue,
        );
        assert!((estimate - (-2.0 / 3.0)).abs() < 1e-9, "got {}", estimate);
    }

    #[test]
    fn worst_case_finds_known_max() {
        // L(x, p) = -(p - 0.4)^2 on p ∈ [0, 1] peaks at p = 0.4 with value 0, so the
        // inner maximiser should report a value very close to 0.
        let inverted_bowl = |_params: &[f64], pos: [f64; 1]| -(pos[0] - 0.4).powi(2);
        let sobol: Quadrature<1> = Quadrature::Sobol {
            num_points: 16,
            seed: 0,
        };
        let worst = AreaScalarisation::WorstCase {
            inner_maxiter: 60,
            inner_seed: 1,
        };
        let peak = evaluate_area_loss(&inverted_bowl, &[0.0], &unit_interval(), &sobol, worst);
        assert!(peak > -1e-3, "expected ~0, got {}", peak);
    }

    #[test]
    fn gaussian_prior_expected_value_matches_known_mean() {
        // For p ~ N(1, 0.25) (variance 0.25): E[p^2] = mean^2 + variance = 1 + 0.25 = 1.25.
        // Truncating at 5σ perturbs this only far below the test tolerance.
        let squared = |_params: &[f64], pos: [f64; 1]| pos[0] * pos[0];
        let normal: Prior<1> = Prior::Gaussian {
            mean: [1.0],
            cov_diag: [0.25],
            truncation_sigmas: 5.0,
        };
        let sobol: Quadrature<1> = Quadrature::Sobol {
            num_points: 4096,
            seed: 0,
        };
        let estimate = evaluate_area_loss(
            &squared,
            &[0.0],
            &normal,
            &sobol,
            AreaScalarisation::ExpectedValue,
        );
        assert!((estimate - 1.25).abs() < 5e-2, "got {}", estimate);
    }

    #[test]
    fn cvar_concentrates_on_tail() {
        // Roughly 10 % of the mass sits on the 100-valued spike (p > 0.9), so the mean
        // is ≈ 10.9 while CVaR(α = 0.1) — the average of the worst 10 % — is ≈ 100.
        let spike = |_params: &[f64], pos: [f64; 1]| if pos[0] > 0.9 { 100.0 } else { 1.0 };
        let sobol: Quadrature<1> = Quadrature::Sobol {
            num_points: 1024,
            seed: 0,
        };
        let tail_measure = AreaScalarisation::Cvar { alpha: 0.1 };

        let expected = evaluate_area_loss(
            &spike,
            &[0.0],
            &unit_interval(),
            &sobol,
            AreaScalarisation::ExpectedValue,
        );
        let tail = evaluate_area_loss(&spike, &[0.0], &unit_interval(), &sobol, tail_measure);

        assert!(
            tail > expected * 5.0,
            "cvar {} should be >> mean {}",
            tail,
            expected
        );
    }

    #[test]
    fn rejects_zero_quadrature_points() {
        // An empty quadrature cannot estimate anything, so the fallible entry point must error.
        let empty: Quadrature<1> = Quadrature::Sobol {
            num_points: 0,
            seed: 0,
        };
        let unit = |_params: &[f64], _pos: [f64; 1]| 1.0;
        let outcome = try_evaluate_area_loss(
            &unit,
            &[0.0],
            &unit_interval(),
            &empty,
            AreaScalarisation::ExpectedValue,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn rejects_degenerate_uniform_bounds() {
        // A zero-width interval has no length, so a uniform density over it is undefined.
        let collapsed: Prior<1> = Prior::Uniform {
            bounds: [(1.0, 1.0)],
        };
        assert!(collapsed.validate().is_err());
    }

    #[test]
    fn gauss_legendre_1d_nodes_symmetric() {
        // On [-1, 1] the weights must sum to the interval length 2, and both the nodes
        // and the weights must be mirror-symmetric about the origin.
        for n in 2..=6 {
            let (nodes, weights) = gauss_legendre_1d(n);
            assert_eq!(nodes.len(), n);
            assert_eq!(weights.len(), n);

            let total_w: f64 = weights.iter().sum();
            assert!(
                (total_w - 2.0).abs() < 1e-10,
                "n={}: total_w={}",
                n,
                total_w
            );

            // Pair entry i with its mirror n - 1 - i by walking the slice from both ends;
            // only the first half needs checking.
            let node_pairs = nodes.iter().zip(nodes.iter().rev());
            let weight_pairs = weights.iter().zip(weights.iter().rev());
            for (i, ((node_lo, node_hi), (w_lo, w_hi))) in
                node_pairs.zip(weight_pairs).enumerate().take(n / 2)
            {
                assert!(
                    (node_lo + node_hi).abs() < 1e-10,
                    "n={}, i={}: nodes={:?}",
                    n,
                    i,
                    nodes
                );
                assert!(
                    (w_lo - w_hi).abs() < 1e-10,
                    "n={}, i={}: weights={:?}",
                    n,
                    i,
                    weights
                );
            }
        }
    }

    #[test]
    fn standard_normal_cdf_known_values() {
        // Reference values: Φ(0) = 0.5, Φ(1) ≈ 0.8413447, Φ(-1) ≈ 0.1586553,
        // each paired with its own tolerance.
        let cases = [
            (0.0, 0.5, 1e-6),
            (1.0, 0.8413447, 1e-4),
            (-1.0, 0.1586553, 1e-4),
        ];
        for &(x, expected, tol) in cases.iter() {
            assert!((standard_normal_cdf(x) - expected).abs() < tol);
        }
    }

    #[test]
    fn inv_standard_normal_round_trip() {
        // Φ(Φ⁻¹(p)) should recover p for probes across the body of the distribution.
        let probes = [0.05_f64, 0.25, 0.5, 0.75, 0.95];
        for p in probes.iter().copied() {
            let z = inv_standard_normal(p);
            let back = standard_normal_cdf(z);
            assert!((p - back).abs() < 1e-3, "p={}, z={}, p2={}", p, z, back);
        }
    }
}