Skip to main content

blr_core/
gaussian.rs

1//! Multivariate Gaussian distribution N(mean, cov).
2//!
3//! Provides a simple wrapper around a mean vector and a covariance matrix
4//! stored as a row-major flat `Vec<f64>`. This module is used internally by
5//! [`crate::ard`] to represent the posterior weight distribution; most users
6//! interact with it indirectly through [`crate::FittedArd`].
7//!
8//! ## Overview
9//!
10//! A D-dimensional Gaussian `N(μ, Σ)` is constructed with:
11//!
12//! - `mean: Vec<f64>` — length D posterior mean vector μ
13//! - `cov: Vec<f64>` — length D² covariance matrix Σ (row-major)
14//!
15//! The covariance is stored as a **row-major** D×D flattened vector.
16//! All public methods use `Vec<f64>` / `&[f64]` so callers are not forced
17//! to depend on `faer`.
18//!
19//! ## Example
20//!
21//! ```rust
22//! use blr_core::Gaussian;
23//!
24//! // 2D Gaussian: mean=[1.0, 2.0], covariance=identity
25//! let mean = vec![1.0_f64, 2.0];
26//! let cov  = vec![1.0, 0.0,   // row 0
27//!                 0.0, 1.0];  // row 1
28//! let g = Gaussian::new(mean.clone(), cov).expect("valid 2×2 covariance");
29//! assert_eq!(g.mean, mean);
30//! ```
31
32use faer::linalg::solvers::Solve;
33use faer::{Accum, Mat, Par, Side};
34
35use crate::BLRError;
36
37// ── Helper: Cholesky log-determinant ──────────────────────────────────────────
38
39/// Compute log|A| for a symmetric positive-definite matrix via Cholesky.
40///
41/// Returns `Err(BLRError::SingularMatrix)` if any pivot is non-positive.
42pub(crate) fn cholesky_logdet(mat: &Mat<f64>, d: usize) -> Result<f64, BLRError> {
43    let mut a = mat.clone();
44    for j in 0..d {
45        let mut diag = a[(j, j)];
46        for k in 0..j {
47            let l_jk = a[(j, k)];
48            diag -= l_jk * l_jk;
49        }
50        if diag <= 0.0 {
51            return Err(BLRError::SingularMatrix);
52        }
53        let l_jj = diag.sqrt();
54        a[(j, j)] = l_jj;
55        for i in (j + 1)..d {
56            let mut s = a[(i, j)];
57            for k in 0..j {
58                s -= a[(i, k)] * a[(j, k)];
59            }
60            a[(i, j)] = s / l_jj;
61        }
62    }
63    Ok(2.0 * (0..d).map(|j| a[(j, j)].ln()).sum::<f64>())
64}
65
/// Multivariate Gaussian `N(mean, cov)`.
///
/// The covariance is a row-major flattened `D × D` matrix, where `D` is the
/// length of `mean` at construction time. Both data fields are public, so
/// code that mutates them afterwards is responsible for keeping their
/// lengths consistent with the stored dimension.
#[derive(Debug, Clone, PartialEq)]
pub struct Gaussian {
    /// Posterior mean — length D.
    pub mean: Vec<f64>,
    /// Posterior covariance, row-major D×D.
    pub cov: Vec<f64>,
    /// Dimension D, fixed when the value is constructed.
    dim: usize,
}
73
74impl Gaussian {
75    /// Create a new Gaussian, validating dimensions.
76    pub fn new(mean: Vec<f64>, cov: Vec<f64>) -> Result<Self, BLRError> {
77        let d = mean.len();
78        if cov.len() != d * d {
79            return Err(BLRError::DimMismatch {
80                expected: d * d,
81                got: cov.len(),
82            });
83        }
84        Ok(Self { mean, cov, dim: d })
85    }
86
87    /// Dimension D of the distribution.
88    pub fn dim(&self) -> usize {
89        self.dim
90    }
91
92    /// Per-dimension standard deviations: `sqrt(diag(cov))`.
93    pub fn std(&self) -> Vec<f64> {
94        let d = self.dim;
95        (0..d).map(|i| self.cov[i * d + i].sqrt()).collect()
96    }
97
98    /// Marginal distribution at index `i`: returns `(mean[i], std[i])`.
99    pub fn marginal(&self, i: usize) -> (f64, f64) {
100        let d = self.dim;
101        (self.mean[i], self.cov[i * d + i].sqrt())
102    }
103
104    /// Log probability density `log N(x; mean, cov)`.
105    ///
106    /// Uses Cholesky of `cov` for numerical stability.
107    pub fn log_pdf(&self, x: &[f64]) -> f64 {
108        let d = self.dim;
109        debug_assert_eq!(x.len(), d);
110
111        let sigma = Mat::<f64>::from_fn(d, d, |i, j| self.cov[i * d + j]);
112        let diff = Mat::<f64>::from_fn(d, 1, |i, _| x[i] - self.mean[i]);
113
114        let llt = sigma
115            .llt(Side::Lower)
116            .expect("Covariance must be positive-definite for log_pdf");
117
118        // Solve L · L^T · z = diff  →  ||z||^2 = diff^T Σ^{-1} diff
119        let z = llt.solve(diff.as_ref());
120        let quadratic: f64 = (0..d)
121            .map(|i| {
122                let v = z[(i, 0)];
123                v * v
124            })
125            .sum();
126
127        // logdet(Σ) via manual Cholesky diagonal (reuse sigma clone)
128        let logdet = cholesky_logdet(&sigma, d).expect("Covariance must be PD");
129
130        -0.5 * quadratic - 0.5 * logdet - (d as f64 / 2.0) * (2.0 * std::f64::consts::PI).ln()
131    }
132
133    /// Bayesian update: computes `p(self | y)` where `y = A·self + ε`,
134    /// `ε ~ N(0, σ²·I_N)` (homoscedastic noise).
135    ///
136    /// `a` is the measurement matrix `A` (n_obs × d_feat), row-major flat slice.
137    /// `noise_variance` is the scalar observation noise variance σ² > 0.
138    ///
139    /// ## Adaptive dispatch
140    ///
141    /// Two algebraically equivalent forms are available; this method selects
142    /// whichever minimises the size of the required Cholesky factorisation:
143    ///
144    /// | Condition | Form chosen | Cholesky size |
145    /// |-----------|-------------|---------------|
146    /// | `n_obs < d_feat`  | Gram / Kalman-gain (observation-space) | N×N |
147    /// | `n_obs >= d_feat` | Precision / Woodbury (parameter-space)  | D×D |
148    ///
149    /// The two forms are related by the Woodbury matrix identity; see
150    /// `dev/blog/blr-and-ard.md` Appendix A for the derivation.
151    /// The precision form derives Σ_prior⁻¹ directly from `self.cov` — no
152    /// isotropic approximation is made, and the forms agree within floating-point
153    /// rounding error.
154    ///
155    /// Returns the updated Gaussian representing the posterior distribution.
156    pub fn condition(
157        self,
158        a: &[f64],
159        n_obs: usize,
160        d_feat: usize,
161        y: &[f64],
162        noise_variance: f64,
163    ) -> Result<Self, BLRError> {
164        debug_assert_eq!(a.len(), n_obs * d_feat);
165        debug_assert_eq!(y.len(), n_obs);
166        debug_assert!(noise_variance > 0.0, "noise_variance must be positive");
167        // Private helpers derive D from self.dim; assert caller agrees (DD-B).
168        debug_assert_eq!(self.dim, d_feat, "d_feat must equal Gaussian dimension");
169
170        if n_obs < d_feat {
171            self.condition_gram_form(a, n_obs, y, noise_variance)
172        } else {
173            self.condition_precision_form(a, n_obs, y, noise_variance)
174        }
175    }
176
177    /// (internal) Observation-space (N×N Gram / Kalman-gain) form of condition().
178    /// Cheaper when n_obs < d_feat (N×N Cholesky vs D×D).
179    fn condition_gram_form(
180        self,
181        a: &[f64],
182        n_obs: usize,
183        y: &[f64],
184        noise_variance: f64,
185    ) -> Result<Self, BLRError> {
186        let d = self.dim;
187
188        let a_mat = Mat::<f64>::from_fn(n_obs, d, |i, j| a[i * d + j]);
189        let mu_mat = Mat::<f64>::from_fn(d, 1, |i, _| self.mean[i]);
190        let sigma_mat = Mat::<f64>::from_fn(d, d, |i, j| self.cov[i * d + j]);
191
192        // Gram = A Σ A^T + σ²·I_N  (N×N)
193        let a_sigma_t = {
194            // A Σ = (N×D) * (D×D) → N×D
195            let mut tmp = Mat::<f64>::zeros(n_obs, d);
196            faer::linalg::matmul::matmul(
197                tmp.as_mut(),
198                Accum::Replace,
199                a_mat.as_ref(),
200                sigma_mat.as_ref(),
201                1.0_f64,
202                Par::Seq,
203            );
204            tmp
205        };
206        let mut gram = {
207            // A_sigma * A^T  (N×D) * (D×N) → N×N
208            // i.e. A Σ A^T
209            let mut tmp = Mat::<f64>::zeros(n_obs, n_obs);
210            faer::linalg::matmul::matmul(
211                tmp.as_mut(),
212                Accum::Replace,
213                a_sigma_t.as_ref(),
214                a_mat.as_ref().transpose(),
215                1.0_f64,
216                Par::Seq,
217            );
218            tmp
219        };
220        // Add σ²·I_N to gram
221        for i in 0..n_obs {
222            gram[(i, i)] += noise_variance;
223        }
224
225        let llt_gram = gram
226            .llt(Side::Lower)
227            .map_err(|_| BLRError::SingularMatrix)?;
228
229        // sigma_at = Σ A^T  (D×N)
230        let sigma_at = {
231            let mut tmp = Mat::<f64>::zeros(d, n_obs);
232            faer::linalg::matmul::matmul(
233                tmp.as_mut(),
234                Accum::Replace,
235                sigma_mat.as_ref(),
236                a_mat.as_ref().transpose(),
237                1.0_f64,
238                Par::Seq,
239            );
240            tmp
241        };
242
243        // residual = y - A μ  (N)
244        let a_mu = {
245            let mut tmp = Mat::<f64>::zeros(n_obs, 1);
246            faer::linalg::matmul::matmul(
247                tmp.as_mut(),
248                Accum::Replace,
249                a_mat.as_ref(),
250                mu_mat.as_ref(),
251                1.0_f64,
252                Par::Seq,
253            );
254            tmp
255        };
256        let residual_mat = Mat::<f64>::from_fn(n_obs, 1, |i, _| y[i] - a_mu[(i, 0)]);
257
258        // Solve Gram * Z = sigma_at^T  →  Z is N×D
259        let z = llt_gram.solve(sigma_at.as_ref().transpose());
260
261        // mu' = mu + sigma_at * Gram^{-1} * residual = mu + Z^T * residual
262        let delta_mu = {
263            let mut tmp = Mat::<f64>::zeros(d, 1);
264            faer::linalg::matmul::matmul(
265                tmp.as_mut(),
266                Accum::Replace,
267                z.as_ref().transpose(),
268                residual_mat.as_ref(),
269                1.0_f64,
270                Par::Seq,
271            );
272            tmp
273        };
274
275        // Sigma' = Sigma - sigma_at * Z  (D×D)
276        let mut sigma_new_mat = sigma_mat.clone();
277        faer::linalg::matmul::matmul(
278            sigma_new_mat.as_mut(),
279            Accum::Add,
280            sigma_at.as_ref(),
281            z.as_ref(),
282            -1.0_f64,
283            Par::Seq,
284        );
285
286        let sigma_new_ref = sigma_new_mat.as_ref();
287        let new_mean: Vec<f64> = (0..d).map(|i| self.mean[i] + delta_mu[(i, 0)]).collect();
288        let new_cov: Vec<f64> = (0..d)
289            .flat_map(|i| (0..d).map(move |j| sigma_new_ref[(i, j)]))
290            .collect();
291
292        Gaussian::new(new_mean, new_cov)
293    }
294
295    /// (internal) Parameter-space (D×D precision / Woodbury) form of condition().
296    /// Cheaper when n_obs >= d_feat (D×D Cholesky vs N×N).
297    /// Derives Σ_prior⁻¹ exactly from self.cov — algebraically equivalent to
298    /// condition_gram_form() via the Woodbury matrix identity.
299    fn condition_precision_form(
300        self,
301        a: &[f64],
302        n_obs: usize,
303        y: &[f64],
304        noise_variance: f64,
305    ) -> Result<Self, BLRError> {
306        let d = self.dim; // canonical D; public condition() asserts self.dim == d_feat
307        let beta = 1.0 / noise_variance;
308
309        let a_mat = Mat::<f64>::from_fn(n_obs, d, |i, j| a[i * d + j]);
310        let y_mat = Mat::<f64>::from_fn(n_obs, 1, |i, _| y[i]);
311        let sigma_prior = Mat::<f64>::from_fn(d, d, |i, j| self.cov[i * d + j]);
312        let mu_prior = Mat::<f64>::from_fn(d, 1, |i, _| self.mean[i]);
313
314        // Step 1: Cholesky of Σ_prior → Σ_prior⁻¹ and Σ_prior⁻¹·μ_prior
315        let llt_prior = sigma_prior
316            .llt(Side::Lower)
317            .map_err(|_| BLRError::SingularMatrix)?;
318        let eye_d = Mat::<f64>::identity(d, d);
319        let sigma_prior_inv = llt_prior.solve(eye_d.as_ref()); // D×D
320        let prec_mu_prior = llt_prior.solve(mu_prior.as_ref()); // D×1
321
322        // Step 2: A^T A  (D×D)
323        let mut at_a = Mat::<f64>::zeros(d, d);
324        faer::linalg::matmul::matmul(
325            at_a.as_mut(),
326            Accum::Replace,
327            a_mat.as_ref().transpose(),
328            a_mat.as_ref(),
329            1.0_f64,
330            Par::Seq,
331        );
332
333        // Step 3: Precision matrix P = Σ_prior⁻¹ + β·(A^T A)  (D×D)
334        //         Exact Woodbury form — no isotropic approximation.
335        let mut precision = sigma_prior_inv;
336        for i in 0..d {
337            for j in 0..d {
338                precision[(i, j)] += beta * at_a[(i, j)];
339            }
340        }
341
342        // Step 4: Cholesky of P
343        let llt_post = precision
344            .llt(Side::Lower)
345            .map_err(|_| BLRError::SingularMatrix)?;
346
347        // Step 5: Σ_post = P⁻¹  (solve with identity)
348        let sigma_post = llt_post.solve(eye_d.as_ref());
349
350        // Step 6: RHS = Σ_prior⁻¹·μ_prior + β·A^T·y
351        let mut at_y = Mat::<f64>::zeros(d, 1);
352        faer::linalg::matmul::matmul(
353            at_y.as_mut(),
354            Accum::Replace,
355            a_mat.as_ref().transpose(),
356            y_mat.as_ref(),
357            1.0_f64,
358            Par::Seq,
359        );
360        let rhs = Mat::<f64>::from_fn(d, 1, |i, _| prec_mu_prior[(i, 0)] + beta * at_y[(i, 0)]);
361
362        // Step 7: μ_post = Σ_post·rhs
363        let mu_post = llt_post.solve(rhs.as_ref());
364
365        let new_mean: Vec<f64> = (0..d).map(|i| mu_post[(i, 0)]).collect();
366        let sigma_ref = sigma_post.as_ref();
367        let new_cov: Vec<f64> = (0..d)
368            .flat_map(|i| (0..d).map(move |j| sigma_ref[(i, j)]))
369            .collect();
370
371        Gaussian::new(new_mean, new_cov)
372    }
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-major `d × d` matrix with `scale` on the diagonal and zeros elsewhere.
    fn scaled_identity(d: usize, scale: f64) -> Vec<f64> {
        let mut flat = vec![0.0_f64; d * d];
        for i in 0..d {
            flat[i * d + i] = scale;
        }
        flat
    }

    #[test]
    fn test_std_known() {
        // cov = diag(4, 9) → std = (2, 3)
        let gaussian = Gaussian::new(vec![0.0, 0.0], vec![4.0, 0.0, 0.0, 9.0]).unwrap();
        let std = gaussian.std();
        let tol = 1e-10;
        assert!((std[0] - 2.0).abs() < tol, "std[0]={}", std[0]);
        assert!((std[1] - 3.0).abs() < tol, "std[1]={}", std[1]);
    }

    #[test]
    fn test_log_pdf_standard_normal() {
        // At the mean of a standard normal the density is (2π)^(-D/2).
        let d = 3usize;
        let gaussian = Gaussian::new(vec![0.0; d], scaled_identity(d, 1.0)).unwrap();
        let origin = vec![0.0; d];
        let lp = gaussian.log_pdf(&origin);
        let expected = -(d as f64) / 2.0 * (2.0 * std::f64::consts::PI).ln();
        assert!(
            (lp - expected).abs() < 1e-10,
            "log_pdf={lp:.6}, expected={expected:.6}"
        );
    }

    #[test]
    fn test_marginal() {
        // cov = diag(1, 4, 9); coordinate 1 has mean 2 and std 2.
        let diag_cov = vec![1.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 9.0];
        let gaussian = Gaussian::new(vec![1.0, 2.0, 3.0], diag_cov).unwrap();
        let (mu, sigma) = gaussian.marginal(1);
        assert!((mu - 2.0).abs() < 1e-10);
        assert!((sigma - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_dim_mismatch() {
        // A 3-vector mean needs 9 covariance entries, not 4.
        assert!(Gaussian::new(vec![0.0; 3], vec![1.0; 4]).is_err());
    }

    // ── condition() tests ─────────────────────────────────────────────────────

    #[test]
    fn test_condition_gram_form_analytic_2d() {
        // Prior N(0, I₂), A = I₂, σ² = 1, y = [1, 2]
        // → posterior covariance 0.5·I₂ and posterior mean [0.5, 1.0].
        let prior = Gaussian::new(vec![0.0, 0.0], vec![1.0, 0.0, 0.0, 1.0]).unwrap();
        let a = vec![1.0_f64, 0.0, 0.0, 1.0];
        let y = vec![1.0_f64, 2.0];
        let post = prior.condition_gram_form(&a, 2, &y, 1.0).unwrap();
        let tol = 1e-12;
        assert!((post.mean[0] - 0.5).abs() < tol, "mean[0]={}", post.mean[0]);
        assert!((post.mean[1] - 1.0).abs() < tol, "mean[1]={}", post.mean[1]);
        assert!((post.cov[0] - 0.5).abs() < tol, "cov[0,0]={}", post.cov[0]);
        assert!((post.cov[1]).abs() < tol, "cov[0,1]={}", post.cov[1]);
        assert!((post.cov[2]).abs() < tol, "cov[1,0]={}", post.cov[2]);
        assert!((post.cov[3] - 0.5).abs() < tol, "cov[1,1]={}", post.cov[3]);
    }

    #[test]
    fn test_condition_precision_form_analytic_2d() {
        // Same closed-form case as the gram-form test, via the precision path.
        let prior = Gaussian::new(vec![0.0, 0.0], vec![1.0, 0.0, 0.0, 1.0]).unwrap();
        let a = vec![1.0_f64, 0.0, 0.0, 1.0];
        let y = vec![1.0_f64, 2.0];
        let post = prior.condition_precision_form(&a, 2, &y, 1.0).unwrap();
        let tol = 1e-12;
        assert!((post.mean[0] - 0.5).abs() < tol, "mean[0]={}", post.mean[0]);
        assert!((post.mean[1] - 1.0).abs() < tol, "mean[1]={}", post.mean[1]);
        assert!((post.cov[0] - 0.5).abs() < tol, "cov[0,0]={}", post.cov[0]);
        assert!((post.cov[1]).abs() < tol, "cov[0,1]={}", post.cov[1]);
        assert!((post.cov[2]).abs() < tol, "cov[1,0]={}", post.cov[2]);
        assert!((post.cov[3] - 0.5).abs() < tol, "cov[1,1]={}", post.cov[3]);
    }

    #[test]
    fn test_condition_parity_n8_d6() {
        let n = 8usize;
        let d = 6usize;
        // Deterministic synthetic A (n×d); halved to keep conditioning reasonable.
        let a: Vec<f64> = (0..n * d)
            .map(|k| (k as f64 * 0.3141592653589793).sin() * 0.5)
            .collect();
        // Synthetic observations.
        let y: Vec<f64> = (0..n).map(|i| (i as f64 * 0.7).cos()).collect();
        // Zero-mean prior with identity covariance.
        let cov_prior = scaled_identity(d, 1.0);
        let noise_variance = 0.5_f64;

        let post_gram = Gaussian::new(vec![0.0; d], cov_prior.clone())
            .unwrap()
            .condition_gram_form(&a, n, &y, noise_variance)
            .unwrap();
        let post_prec = Gaussian::new(vec![0.0; d], cov_prior)
            .unwrap()
            .condition_precision_form(&a, n, &y, noise_variance)
            .unwrap();

        let tol = 1e-10;
        for i in 0..d {
            let (gram, prec) = (post_gram.mean[i], post_prec.mean[i]);
            assert!(
                (gram - prec).abs() < tol,
                "mean[{}]: gram={}, prec={}",
                i,
                gram,
                prec
            );
        }
        for k in 0..d * d {
            let (gram, prec) = (post_gram.cov[k], post_prec.cov[k]);
            assert!(
                (gram - prec).abs() < tol,
                "cov[{}]: gram={}, prec={}",
                k,
                gram,
                prec
            );
        }
    }

    #[test]
    fn test_condition_dispatch_n_lt_d() {
        // N=3 < D=10 → condition() should take the gram form.
        let n = 3usize;
        let d = 10usize;
        let a: Vec<f64> = (0..n * d).map(|k| (k as f64 * 0.17).sin()).collect();
        let y: Vec<f64> = (0..n).map(|i| i as f64 + 1.0).collect();
        let cov = scaled_identity(d, 2.0);
        let mean = vec![0.5_f64; d];
        let noise_variance = 0.3_f64;

        let via_dispatch = Gaussian::new(mean.clone(), cov.clone()).unwrap();
        let via_gram = Gaussian::new(mean, cov).unwrap();

        let post_dispatch = via_dispatch.condition(&a, n, d, &y, noise_variance).unwrap();
        let post_gram = via_gram
            .condition_gram_form(&a, n, &y, noise_variance)
            .unwrap();

        for i in 0..d {
            assert_eq!(
                post_dispatch.mean[i], post_gram.mean[i],
                "mean[{}] mismatch — dispatch did not route to gram form",
                i
            );
        }
    }

    #[test]
    fn test_condition_dispatch_n_gt_d() {
        // N=100 >= D=6 → condition() should take the precision form.
        let n = 100usize;
        let d = 6usize;
        let a: Vec<f64> = (0..n * d).map(|k| (k as f64 * 0.13).sin()).collect();
        let y: Vec<f64> = (0..n).map(|i| (i as f64 * 0.23).cos()).collect();
        let cov = scaled_identity(d, 1.0);
        let mean = vec![0.0_f64; d];
        let noise_variance = 1.0_f64;

        let via_dispatch = Gaussian::new(mean.clone(), cov.clone()).unwrap();
        let via_prec = Gaussian::new(mean, cov).unwrap();

        let post_dispatch = via_dispatch.condition(&a, n, d, &y, noise_variance).unwrap();
        let post_prec = via_prec
            .condition_precision_form(&a, n, &y, noise_variance)
            .unwrap();

        for i in 0..d {
            assert_eq!(
                post_dispatch.mean[i], post_prec.mean[i],
                "mean[{}] mismatch — dispatch did not route to precision form",
                i
            );
        }
    }
}