Skip to main content

scirs2_stats/inla/
hyperparameters.rs

1//! Hyperparameter exploration and integration for INLA
2//!
3//! This module handles:
4//! - Evaluating the log-posterior of hyperparameters via Laplace approximation
5//! - Grid-based exploration of the hyperparameter space
6//! - Central Composite Design (CCD) integration points
7//! - Numerical integration on the log scale
8
9use scirs2_core::ndarray::{Array1, Array2};
10
11use super::laplace;
12use super::types::{
13    HyperparameterPosterior, INLAConfig, IntegrationStrategy, LatentGaussianModel, LikelihoodFamily,
14};
15use crate::error::StatsError;
16
17/// A single hyperparameter configuration with its log-posterior value
18#[derive(Debug, Clone)]
19pub struct HyperparameterPoint {
20    /// Hyperparameter values (e.g., log-precision for Gaussian, etc.)
21    pub theta: Vec<f64>,
22    /// Log-posterior density p̃(θ|y) evaluated via Laplace approximation
23    pub log_posterior: f64,
24    /// The mode result at this hyperparameter configuration
25    pub mode: Array1<f64>,
26    /// Diagonal of the inverse negative Hessian (marginal variances at this θ)
27    pub marginal_variances: Array1<f64>,
28}
29
30/// Evaluate the log-posterior of hyperparameters using Laplace approximation
31///
32/// log p̃(θ|y) ∝ log p(y|θ) + log p(θ)
33///
34/// where log p(y|θ) is approximated using the Laplace method.
35///
36/// # Arguments
37/// * `theta` - Hyperparameter value (log-precision scale)
38/// * `model` - The latent Gaussian model
39/// * `config` - INLA configuration
40///
41/// # Returns
42/// A `HyperparameterPoint` containing the log-posterior and associated quantities
43pub fn evaluate_hyperparameter(
44    theta: f64,
45    model: &LatentGaussianModel,
46    config: &INLAConfig,
47) -> Result<HyperparameterPoint, StatsError> {
48    // Scale the precision matrix by exp(theta) (theta is log-precision)
49    let scale = theta.exp();
50    let scaled_precision = &model.precision_matrix * scale;
51
52    // Find the posterior mode at this hyperparameter value
53    let mode_result = laplace::find_mode(
54        &scaled_precision,
55        &model.y,
56        &model.design_matrix,
57        model.likelihood,
58        model.n_trials.as_ref(),
59        model.observation_precision,
60        config.max_newton_iter,
61        config.newton_tol,
62        config.newton_damping,
63    )?;
64
65    // Compute Laplace approximation to log p(y|θ)
66    let log_marginal = laplace::laplace_log_marginal_likelihood(&mode_result, &scaled_precision)?;
67
68    // Log prior on θ (flat/improper prior by default)
69    let log_prior_theta = log_hyperprior(theta, config);
70
71    // Compute marginal variances (diagonal of H^{-1})
72    let marginal_vars = laplace::inverse_diagonal(&mode_result.neg_hessian)?;
73
74    Ok(HyperparameterPoint {
75        theta: vec![theta],
76        log_posterior: log_marginal + log_prior_theta,
77        mode: mode_result.mode,
78        marginal_variances: marginal_vars,
79    })
80}
81
82/// Log-prior for hyperparameters
83///
84/// Uses a flat prior by default, or a Gaussian prior if range is specified.
85fn log_hyperprior(theta: f64, config: &INLAConfig) -> f64 {
86    match config.hyperparameter_range {
87        Some((lo, hi)) => {
88            // Penalized complexity prior: log-Gaussian centered at midpoint
89            let mid = (lo + hi) / 2.0;
90            let scale = (hi - lo) / 4.0; // 95% within range
91            if scale <= 0.0 {
92                return 0.0;
93            }
94            -0.5 * ((theta - mid) / scale).powi(2)
95        }
96        None => 0.0, // flat (improper) prior
97    }
98}
99
100/// Explore the hyperparameter space on a grid
101///
102/// Creates a 1D grid of hyperparameter values and evaluates the
103/// Laplace-approximated log-posterior at each point.
104///
105/// # Arguments
106/// * `model` - The latent Gaussian model
107/// * `config` - INLA configuration
108///
109/// # Returns
110/// Vector of `HyperparameterPoint` sorted by log-posterior (descending)
111pub fn explore_hyperparameter_grid(
112    model: &LatentGaussianModel,
113    config: &INLAConfig,
114) -> Result<Vec<HyperparameterPoint>, StatsError> {
115    let n_grid = config.n_hyperparameter_grid;
116    if n_grid == 0 {
117        return Err(StatsError::InvalidArgument(
118            "Number of hyperparameter grid points must be positive".to_string(),
119        ));
120    }
121
122    // Determine grid range
123    let (lo, hi) = config.hyperparameter_range.unwrap_or((-3.0, 3.0));
124
125    let grid_points = create_grid(lo, hi, n_grid);
126
127    let mut results = Vec::with_capacity(n_grid);
128    for &theta in &grid_points {
129        match evaluate_hyperparameter(theta, model, config) {
130            Ok(point) => results.push(point),
131            Err(_) => {
132                // Skip points where mode finding fails (e.g., numerical issues)
133                continue;
134            }
135        }
136    }
137
138    if results.is_empty() {
139        return Err(StatsError::ConvergenceError(
140            "INLA failed to evaluate any hyperparameter grid point".to_string(),
141        ));
142    }
143
144    // Sort by log-posterior (descending)
145    results.sort_by(|a, b| {
146        b.log_posterior
147            .partial_cmp(&a.log_posterior)
148            .unwrap_or(std::cmp::Ordering::Equal)
149    });
150
151    Ok(results)
152}
153
/// Create a uniform grid of `n` points in `[lo, hi]`.
///
/// * `n == 0` yields an empty grid (instead of underflowing `n - 1`).
/// * `n == 1` yields the midpoint of the interval.
/// * `n >= 2` yields evenly spaced points whose first element is `lo` and
///   whose last element is exactly `hi`. The endpoint is pinned so that
///   rounding in `lo + i * step` cannot push it off the interval.
fn create_grid(lo: f64, hi: f64, n: usize) -> Vec<f64> {
    match n {
        0 => Vec::new(),
        1 => vec![(lo + hi) / 2.0],
        _ => {
            let last = n - 1;
            let step = (hi - lo) / last as f64;
            (0..n)
                .map(|i| if i == last { hi } else { lo + i as f64 * step })
                .collect()
        }
    }
}
162
163/// Generate Central Composite Design (CCD) integration points
164///
165/// CCD is more efficient than a full grid for multivariate hyperparameter
166/// integration. For d hyperparameters, it uses:
167/// - 1 center point
168/// - 2*d axial points at distance ±α from center
169/// - 2^d factorial points (for d ≤ 4, else a fraction)
170///
171/// Total points: 1 + 2*d + min(2^d, 2*d) for large d
172///
173/// # Arguments
174/// * `n_hyperparams` - Number of hyperparameters
175///
176/// # Returns
177/// Vector of point coordinates (each is a `Vec<f64>` of length n_hyperparams)
178/// The points are on a standardized scale (centered at 0, scaled by 1).
179pub fn ccd_integration_points(n_hyperparams: usize) -> Result<Vec<Vec<f64>>, StatsError> {
180    if n_hyperparams == 0 {
181        return Err(StatsError::InvalidArgument(
182            "Number of hyperparameters must be positive".to_string(),
183        ));
184    }
185
186    let mut points = Vec::new();
187
188    // Center point
189    points.push(vec![0.0; n_hyperparams]);
190
191    // Axial distance: alpha = sqrt(n_hyperparams) for rotatability
192    let alpha = (n_hyperparams as f64).sqrt();
193
194    // Axial points: ±alpha along each axis
195    for d in 0..n_hyperparams {
196        let mut point_pos = vec![0.0; n_hyperparams];
197        point_pos[d] = alpha;
198        points.push(point_pos);
199
200        let mut point_neg = vec![0.0; n_hyperparams];
201        point_neg[d] = -alpha;
202        points.push(point_neg);
203    }
204
205    // Factorial points: all combinations of ±1
206    // For large d, use fractional factorial
207    let max_factorial = if n_hyperparams <= 6 {
208        1usize << n_hyperparams // 2^d
209    } else {
210        // Fractional factorial for high dimensions
211        2 * n_hyperparams
212    };
213
214    let n_factorial = (1usize << n_hyperparams).min(max_factorial);
215    for i in 0..n_factorial {
216        let mut point = vec![0.0; n_hyperparams];
217        for d in 0..n_hyperparams {
218            point[d] = if (i >> d) & 1 == 0 { -1.0 } else { 1.0 };
219        }
220        points.push(point);
221    }
222
223    Ok(points)
224}
225
226/// Perform numerical integration on the log scale
227///
228/// Given log-densities at grid points, compute the normalized weights
229/// and the log of the normalizing constant.
230///
231/// Uses the log-sum-exp trick for numerical stability.
232///
233/// # Arguments
234/// * `log_densities` - Log-density values at grid points
235/// * `grid_spacing` - Spacing between grid points (for trapezoidal rule)
236///
237/// # Returns
238/// Tuple of (normalized_weights, log_normalizing_constant)
239pub fn grid_integration(
240    log_densities: &[f64],
241    grid_spacing: f64,
242) -> Result<(Vec<f64>, f64), StatsError> {
243    if log_densities.is_empty() {
244        return Err(StatsError::InvalidArgument(
245            "Log densities array is empty".to_string(),
246        ));
247    }
248
249    // Find maximum for log-sum-exp trick
250    let max_log = log_densities
251        .iter()
252        .copied()
253        .fold(f64::NEG_INFINITY, f64::max);
254
255    if max_log.is_infinite() && max_log < 0.0 {
256        return Err(StatsError::ComputationError(
257            "All log densities are -infinity".to_string(),
258        ));
259    }
260
261    // Compute weights using trapezoidal rule on log scale
262    let n = log_densities.len();
263    let mut weights = Vec::with_capacity(n);
264    for i in 0..n {
265        let trap_factor = if i == 0 || i == n - 1 { 0.5 } else { 1.0 };
266        weights.push((log_densities[i] - max_log).exp() * trap_factor * grid_spacing);
267    }
268
269    let total_weight: f64 = weights.iter().sum();
270    if total_weight <= 0.0 {
271        return Err(StatsError::ComputationError(
272            "Total integration weight is non-positive".to_string(),
273        ));
274    }
275
276    let log_normalizing = max_log + total_weight.ln();
277
278    // Normalize weights
279    let normalized: Vec<f64> = weights.iter().map(|w| w / total_weight).collect();
280
281    Ok((normalized, log_normalizing))
282}
283
284/// Compute posterior summary statistics for a hyperparameter from grid evaluation
285///
286/// # Arguments
287/// * `grid_points` - Grid point values
288/// * `log_densities` - Log-density at each grid point
289/// * `grid_spacing` - Spacing between grid points
290///
291/// # Returns
292/// `HyperparameterPosterior` with mean, variance, and density information
293pub fn summarize_hyperparameter_posterior(
294    grid_points: &[f64],
295    log_densities: &[f64],
296    grid_spacing: f64,
297) -> Result<HyperparameterPosterior, StatsError> {
298    if grid_points.len() != log_densities.len() {
299        return Err(StatsError::DimensionMismatch(
300            "Grid points and log densities must have the same length".to_string(),
301        ));
302    }
303
304    let (weights, _) = grid_integration(log_densities, grid_spacing)?;
305
306    // Compute mean: E[θ] = Σ w_i * θ_i
307    let mean: f64 = weights
308        .iter()
309        .zip(grid_points.iter())
310        .map(|(w, t)| w * t)
311        .sum();
312
313    // Compute variance: Var[θ] = Σ w_i * (θ_i - mean)^2
314    let variance: f64 = weights
315        .iter()
316        .zip(grid_points.iter())
317        .map(|(w, t)| w * (t - mean).powi(2))
318        .sum();
319
320    Ok(HyperparameterPosterior {
321        grid_points: grid_points.to_vec(),
322        log_densities: log_densities.to_vec(),
323        mean,
324        variance,
325    })
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOL: f64 = 1e-10;

    /// Whether `actual` lies within `TOL` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    // ---- create_grid -------------------------------------------------------

    #[test]
    fn test_create_grid() {
        let grid = create_grid(-1.0, 1.0, 5);
        assert_eq!(grid.len(), 5);
        // Endpoints and midpoint of the interval.
        assert!(close(grid[0], -1.0));
        assert!(close(grid[4], 1.0));
        assert!(close(grid[2], 0.0));
    }

    #[test]
    fn test_create_grid_single() {
        // A one-point grid collapses to the midpoint.
        let grid = create_grid(-1.0, 1.0, 1);
        assert_eq!(grid.len(), 1);
        assert!(close(grid[0], 0.0));
    }

    // ---- ccd_integration_points --------------------------------------------

    #[test]
    fn test_ccd_1d() {
        let points = ccd_integration_points(1).expect("CCD should succeed for 1D");
        // 1 center + 2 axial + 2 factorial
        assert_eq!(points.len(), 5);
        // Center first, then the axial pair at +1 / -1 (alpha = sqrt(1)).
        assert!(close(points[0][0], 0.0));
        assert!(close(points[1][0], 1.0));
        assert!(close(points[2][0], -1.0));
    }

    #[test]
    fn test_ccd_2d() {
        let points = ccd_integration_points(2).expect("CCD should succeed for 2D");
        // 1 center + 4 axial + 4 factorial
        assert_eq!(points.len(), 9);
        // The first point is the origin.
        assert!(close(points[0][0], 0.0));
        assert!(close(points[0][1], 0.0));
    }

    #[test]
    fn test_ccd_3d() {
        let points = ccd_integration_points(3).expect("CCD should succeed for 3D");
        // 1 center + 6 axial + 8 factorial
        assert_eq!(points.len(), 15);
    }

    #[test]
    fn test_ccd_zero() {
        assert!(ccd_integration_points(0).is_err());
    }

    // ---- grid_integration --------------------------------------------------

    #[test]
    fn test_grid_integration_uniform() {
        // Flat log-density: interior points weigh 1, endpoints 0.5, total 4,
        // so the normalized weights are 0.125, 0.25, 0.25, 0.25, 0.125.
        let log_densities = vec![0.0, 0.0, 0.0, 0.0, 0.0];
        let (weights, _) =
            grid_integration(&log_densities, 1.0).expect("Integration should succeed");
        assert!(close(weights[0], 0.125));
        assert!(close(weights[2], 0.25));
    }

    #[test]
    fn test_grid_integration_peaked() {
        // Sharply peaked density: nearly all mass sits on the central point.
        let log_densities = vec![-100.0, -10.0, 0.0, -10.0, -100.0];
        let (weights, _) =
            grid_integration(&log_densities, 1.0).expect("Integration should succeed");
        assert!(weights[2] > 0.9);
    }

    #[test]
    fn test_grid_integration_empty() {
        assert!(grid_integration(&[], 1.0).is_err());
    }

    // ---- summarize_hyperparameter_posterior --------------------------------

    #[test]
    fn test_summarize_posterior() {
        // A density symmetric about 0 must have a mean close to 0.
        let grid_points = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let log_densities = vec![-2.0, -0.5, 0.0, -0.5, -2.0];
        let summary = summarize_hyperparameter_posterior(&grid_points, &log_densities, 1.0)
            .expect("Summary should succeed");
        assert!(
            summary.mean.abs() < 0.1,
            "Mean should be near 0, got {}",
            summary.mean
        );
        assert!(summary.variance > 0.0, "Variance should be positive");
    }

    #[test]
    fn test_dimension_mismatch_summary() {
        // Two grid points against three densities must be rejected.
        let grid = vec![1.0, 2.0];
        let densities = vec![0.0, 0.0, 0.0];
        assert!(summarize_hyperparameter_posterior(&grid, &densities, 1.0).is_err());
    }

    // ---- explore_hyperparameter_grid ---------------------------------------

    #[test]
    fn test_explore_grid_gaussian() {
        let n = 3;
        let y = array![1.0, 2.0, 3.0];
        let design = Array2::eye(n);
        let precision = Array2::eye(n);

        let model = LatentGaussianModel::new(y, design, precision, LikelihoodFamily::Gaussian)
            .with_observation_precision(1.0);

        let config = INLAConfig {
            n_hyperparameter_grid: 5,
            hyperparameter_range: Some((-1.0, 1.0)),
            max_newton_iter: 50,
            ..INLAConfig::default()
        };

        let results =
            explore_hyperparameter_grid(&model, &config).expect("Grid exploration should succeed");

        assert!(!results.is_empty(), "Should have some valid grid points");
        // Each consecutive pair must be in non-increasing log-posterior order.
        for pair in results.windows(2) {
            assert!(
                pair[0].log_posterior >= pair[1].log_posterior,
                "Results should be sorted descending"
            );
        }
    }
}