Skip to main content

nuts_rs/transform/
low_rank.rs

1//! Augment the diagonal transformation with a low-rank spectral correction for correlated posteriors.
2
3use std::fmt::Debug;
4use std::iter::repeat_n;
5
6use faer::{Col, ColRef, Mat, MatRef};
7use nuts_derive::Storable;
8use serde::{Deserialize, Serialize};
9
10use crate::transform::{DiagMassMatrix, Transformation};
11use crate::{Math, sampler_stats::SamplerStats};
12
13pub fn mat_all_finite(mat: &MatRef<f64>) -> bool {
14    let mut ok = true;
15    faer::zip!(mat).for_each(|faer::unzip!(val)| ok &= val.is_finite());
16    ok
17}
18
19fn col_all_finite(mat: &ColRef<f64>) -> bool {
20    let mut ok = true;
21    faer::zip!(mat).for_each(|faer::unzip!(val)| ok &= val.is_finite());
22    ok
23}
24
/// The low-rank correction to the affine transformation.
///
/// Stores U (eigenvectors), λ^{1/2} (used for F and J_F), λ^{-1/2} (used for F⁻¹), and
/// the precomputed low-rank contribution to log|det J_{F⁻¹}|.
struct InnerMatrix<M: Math> {
    /// Eigenvectors U, held in the math backend's own storage type.
    vecs: M::EigVectors,
    /// λ^{1/2} — used for the forward map F and its Jacobian J_F
    vals_sqrt: M::EigValues,
    /// λ^{-1/2} — used for the inverse position transform F⁻¹
    vals_sqrt_inv: M::EigValues,
    /// -½ Σ log(λᵢ) — low-rank contribution to log|det J_{F⁻¹}|, precomputed
    /// so we never need to pull eigenvalues back from a device (e.g. GPU).
    logdet_contribution: f64,
    /// Offset in the diagonally rescaled space: it is subtracted right before the
    /// low-rank map on the way into the adapted space and added back right after
    /// the low-rank map on the way out.
    mu: M::Vector,
    /// Number of eigenvalues this correction was built from; surfaced in the
    /// sampler statistics.
    num_eigenvalues: u64,
}
41
42impl<M: Math> Debug for InnerMatrix<M> {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        f.debug_struct("InnerMatrix")
45            .field("vecs", &"<eig vectors>")
46            .field("vals_sqrt", &"<sqrt eig values>")
47            .field("vals_sqrt_inv", &"<inv sqrt eig values>")
48            .field("logdet_contribution", &self.logdet_contribution)
49            .field("num_eigenvalues", &self.num_eigenvalues)
50            .field("mu", &self.mu)
51            .finish()
52    }
53}
54
55impl<M: Math> InnerMatrix<M> {
56    fn new(math: &mut M, mut vals: Col<f64>, vecs: Mat<f64>, mu: Col<f64>) -> Self {
57        // Precompute -½ Σ log(λᵢ) while vals still holds the raw eigenvalues.
58        let logdet_contribution: f64 = vals.iter().map(|&v| -0.5 * v.ln()).sum();
59        let num_eigenvalues = vals.nrows() as u64;
60
61        let vecs = math.new_eig_vectors(
62            vecs.col_iter()
63                .map(|col| col.try_as_col_major().unwrap().as_slice()),
64        );
65
66        // λ^{1/2} — needed for the forward map F and its Jacobian J_F
67        vals.iter_mut().for_each(|x| *x = x.sqrt());
68        let vals_sqrt = math.new_eig_values(vals.try_as_col_major().unwrap().as_slice());
69
70        // λ^{-1/2} — needed for the inverse position transform F⁻¹
71        vals.iter_mut().for_each(|x| *x = x.recip());
72        let vals_sqrt_inv = math.new_eig_values(vals.try_as_col_major().unwrap().as_slice());
73
74        let mu = {
75            let mut array = math.new_array();
76            math.read_from_slice(&mut array, mu.try_as_col_major().unwrap().as_slice());
77            array
78        };
79
80        Self {
81            vecs,
82            vals_sqrt,
83            vals_sqrt_inv,
84            logdet_contribution,
85            mu,
86            num_eigenvalues,
87        }
88    }
89
90    fn logdet(&self) -> f64 {
91        self.logdet_contribution
92    }
93}
94
/// Low-rank + diagonal affine transformation.
///
/// The full forward map (adapted → target) is
///
///   F(y) = σ ⊙ (I + U (diag(λ)^{1/2} − I) Uᵀ) y + μ
///
/// so the inverse (target → adapted) is
///
///   F⁻¹(x) = (I + U (diag(λ)^{-1/2} − I) Uᵀ) ((x − μ) ⊙ σ⁻¹)
///
/// The Jacobian of F is  J_F = diag(σ) (I + U (diag(λ)^{1/2} − I) Uᵀ),
/// so  log|det J_{F⁻¹}| = Σ log(σᵢ⁻¹) − ½ Σ log(λᵢ).
///
/// When a low-rank correction is present, its own offset `mu` is additionally
/// applied in the σ-scaled space: it is subtracted before the low-rank map in
/// F⁻¹ and added after the low-rank map in F.  Being a constant shift, it does
/// not enter the Jacobian.
///
/// In the adapted space the mass matrix is the identity; leapfrog steps
/// operate entirely in that space.  When no eigenvectors are available
/// (early adaptation) the transform falls back to the pure diagonal case.
pub struct LowRankMassMatrix<M: Math> {
    /// Diagonal part of the transform: scales σ, their inverses and the mean μ.
    diag: DiagMassMatrix<M>,
    /// Optional low-rank correction; `None` selects the pure diagonal code path.
    inner: Option<InnerMatrix<M>>,
    /// Configuration this transformation was created with.
    settings: LowRankSettings,
    /// Cached log-determinant returned by the transformation: the diagonal term
    /// plus, when present, the low-rank term.
    logdet: f64,
    /// Monotonically increasing id; bumped whenever the matrix changes.
    id: i64,
}
119
120impl<M: Math> Debug for LowRankMassMatrix<M> {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        f.debug_struct("LowRankMassMatrix")
123            .field("diag", &self.diag)
124            .field("inner", &self.inner)
125            .field("settings", &self.settings)
126            .field("id", &self.id)
127            .finish()
128    }
129}
130
131impl<M: Math> LowRankMassMatrix<M> {
132    pub fn new(math: &mut M, settings: LowRankSettings) -> Self {
133        Self {
134            diag: DiagMassMatrix::new(math, settings.store_mass_matrix),
135            settings,
136            logdet: 0f64,
137            inner: None,
138            id: -1,
139        }
140    }
141
142    /// Initialise from the gradient at the first draw only (no mean / covariance information yet).
143    pub fn update_from_grad(
144        &mut self,
145        math: &mut M,
146        pos: &M::Vector,
147        grad: &M::Vector,
148        fill_invalid: f64,
149        clamp: (f64, f64),
150    ) {
151        self.inner = None;
152        self.diag
153            .update_diag_grad(math, pos, grad, fill_invalid, clamp);
154        self.logdet = self.diag.logdet();
155        self.id += 1;
156    }
157
158    /// Full update from a window of draws and scores.
159    ///
160    /// * `stds`  — diagonal scales σ
161    /// * `mean`  — optimal translation μ* = x̄ + σ² ⊙ ᾱ (in target space)
162    /// * `vals`  — filtered eigenvalues λ of the SPD geometric mean
163    /// * `vecs`  — corresponding eigenvectors U (columns, back-projected to ℝᵈ)
164    pub fn update(
165        &mut self,
166        math: &mut M,
167        stds: Col<f64>,
168        mean: Col<f64>,
169        vals: Col<f64>,
170        vecs: Mat<f64>,
171        mean_low_rank: Col<f64>,
172    ) {
173        if (!col_all_finite(&stds.as_ref())) | (!col_all_finite(&mean.as_ref())) {
174            return;
175        }
176        if (!col_all_finite(&vals.as_ref())) | (!mat_all_finite(&vecs.as_ref())) {
177            return;
178        }
179
180        let mut stds_array = math.new_array();
181        math.read_from_slice(&mut stds_array, stds.try_as_col_major().unwrap().as_slice());
182        let mut mean_array = math.new_array();
183        math.read_from_slice(&mut mean_array, mean.try_as_col_major().unwrap().as_slice());
184        self.diag.set_transform(math, &stds_array, &mean_array);
185
186        let inner = InnerMatrix::new(math, vals, vecs, mean_low_rank);
187        self.logdet = inner.logdet() + self.diag.logdet();
188        self.inner = Some(inner);
189        self.id += 1;
190    }
191}
192
/// Configuration of the low-rank mass matrix.
#[derive(Clone, Debug, Copy, Serialize, Deserialize)]
pub struct LowRankSettings {
    /// When `true`, the diagonal scales and the eigenvalue data are included in
    /// the sampler statistics each time the transformation changes.  The flag is
    /// also handed to the diagonal part on construction.
    pub store_mass_matrix: bool,
    /// Not read anywhere in this file.
    /// NOTE(review): presumably a regularisation strength used by the adaptation
    /// code — confirm against its consumer.
    pub gamma: f64,
    /// Not read anywhere in this file.
    /// NOTE(review): presumably the threshold used when filtering eigenvalues —
    /// confirm against its consumer.
    pub eigval_cutoff: f64,
}
199
200impl Default for LowRankSettings {
201    fn default() -> Self {
202        Self {
203            store_mass_matrix: false,
204            gamma: 1e-5,
205            eigval_cutoff: 2f64,
206        }
207    }
208}
209
/// Per-draw statistics describing the current transformation.
///
/// Every field is `None` on draws where the transformation id did not change.
#[derive(Debug, Storable)]
pub struct MatrixStats {
    /// The transformation version counter at the time of this update.
    /// `Some` only on draws where the transformation changed.
    #[storable(event = "transformation_update")]
    pub transformation_update_id: Option<i64>,
    /// Eigenvalue data of the low-rank correction, padded with NaN up to the
    /// length of `mass_matrix_stds`.  Only populated when mass-matrix storage is
    /// enabled and a low-rank correction exists.
    #[storable(event = "transformation_update", dims("unconstrained_parameter"))]
    pub mass_matrix_eigvals: Option<Vec<f64>>,
    /// Diagonal scales σ.  Only populated when mass-matrix storage is enabled.
    #[storable(event = "transformation_update", dims("unconstrained_parameter"))]
    pub mass_matrix_stds: Option<Vec<f64>>,
    /// Number of eigenvalues in the low-rank correction (`0` when there is none).
    #[storable(event = "transformation_update")]
    pub num_eigenvalues: Option<u64>,
}
223
224impl<M: Math> SamplerStats<M> for LowRankMassMatrix<M> {
225    type Stats = MatrixStats;
226    type StatsOptions = i64;
227
228    fn extract_stats(&self, math: &mut M, last_id: Self::StatsOptions) -> Self::Stats {
229        if self.id != last_id {
230            let num_eigenvalues = Some(
231                self.inner
232                    .as_ref()
233                    .map(|inner| inner.num_eigenvalues)
234                    .unwrap_or(0),
235            );
236            if self.settings.store_mass_matrix {
237                let stds = Some(math.box_array(self.diag.stds()));
238                let eigvals = self
239                    .inner
240                    .as_ref()
241                    .map(|inner| math.eigs_as_array(&inner.vals_sqrt));
242                let mut eigvals = eigvals.map(|x| x.into_vec());
243                if let Some(ref mut eigvals) = eigvals {
244                    eigvals.extend(repeat_n(
245                        f64::NAN,
246                        stds.as_ref().unwrap().len() - eigvals.len(),
247                    ));
248                }
249                MatrixStats {
250                    transformation_update_id: Some(self.id),
251                    mass_matrix_eigvals: eigvals,
252                    mass_matrix_stds: stds.map(|x| x.into_vec()),
253                    num_eigenvalues,
254                }
255            } else {
256                MatrixStats {
257                    transformation_update_id: Some(self.id),
258                    mass_matrix_eigvals: None,
259                    mass_matrix_stds: None,
260                    num_eigenvalues,
261                }
262            }
263        } else {
264            MatrixStats {
265                transformation_update_id: None,
266                mass_matrix_eigvals: None,
267                mass_matrix_stds: None,
268                num_eigenvalues: None,
269            }
270        }
271    }
272}
273
274impl<M: Math> Transformation<M> for LowRankMassMatrix<M> {
275    fn init_from_untransformed_position(
276        &self,
277        math: &mut M,
278        untransformed_position: &M::Vector,
279        untransformed_gradient: &mut M::Vector,
280        transformed_position: &mut M::Vector,
281        transformed_gradient: &mut M::Vector,
282    ) -> Result<(f64, f64), M::LogpErr> {
283        let logp = math.logp_array(untransformed_position, untransformed_gradient)?;
284        self.compute_transformed_position(math, untransformed_position, transformed_position);
285        self.compute_transformed_gradient(math, untransformed_gradient, transformed_gradient);
286        Ok((logp, self.logdet(math)))
287    }
288
289    fn init_from_transformed_position(
290        &self,
291        math: &mut M,
292        untransformed_position: &mut M::Vector,
293        untransformed_gradient: &mut M::Vector,
294        transformed_position: &M::Vector,
295        transformed_gradient: &mut M::Vector,
296    ) -> Result<(f64, f64), M::LogpErr> {
297        self.compute_untransformed_position(math, transformed_position, untransformed_position);
298        let logp = math.logp_array(untransformed_position, untransformed_gradient)?;
299        self.compute_transformed_gradient(math, untransformed_gradient, transformed_gradient);
300        Ok((logp, self.logdet(math)))
301    }
302
303    fn inv_transform_normalize(
304        &self,
305        math: &mut M,
306        untransformed_position: &M::Vector,
307        untransformed_gradient: &M::Vector,
308        transformed_position: &mut M::Vector,
309        transformed_gradient: &mut M::Vector,
310    ) -> Result<f64, M::LogpErr> {
311        self.compute_transformed_position(math, untransformed_position, transformed_position);
312        self.compute_transformed_gradient(math, untransformed_gradient, transformed_gradient);
313        Ok(self.logdet(math))
314    }
315
316    fn transformation_id(&self, _math: &mut M) -> i64 {
317        self.id
318    }
319
320    fn next_stats_options(&self, _math: &mut M, _current: i64) -> i64 {
321        self.id
322    }
323}
324
325impl<M: Math> LowRankMassMatrix<M> {
326    fn compute_transformed_position(
327        &self,
328        math: &mut M,
329        untransformed_position: &M::Vector,
330        transformed_position: &mut M::Vector,
331    ) {
332        math.axpy_out(
333            &self.diag.mean(),
334            &untransformed_position,
335            -1.0,
336            transformed_position,
337        );
338        math.array_mult_inplace(transformed_position, self.diag.inv_stds());
339
340        if let Some(inner) = &self.inner {
341            math.axpy(&inner.mu, transformed_position, -1.0);
342            math.apply_lowrank_transform_inplace(
343                &inner.vecs,
344                &inner.vals_sqrt_inv,
345                transformed_position,
346            );
347        }
348    }
349
350    fn compute_untransformed_position(
351        &self,
352        math: &mut M,
353        transformed_position: &M::Vector,
354        untransformed_position: &mut M::Vector,
355    ) {
356        match &self.inner {
357            None => {
358                math.array_mult(
359                    transformed_position,
360                    &self.diag.stds(),
361                    untransformed_position,
362                );
363            }
364            Some(inner) => {
365                math.apply_lowrank_transform(
366                    &inner.vecs,
367                    &inner.vals_sqrt,
368                    transformed_position,
369                    untransformed_position,
370                );
371
372                math.axpy(&inner.mu, untransformed_position, 1.0);
373                math.array_mult_inplace(untransformed_position, &self.diag.stds());
374            }
375        }
376        math.axpy(&self.diag.mean(), untransformed_position, 1.0);
377    }
378
379    fn compute_transformed_gradient(
380        &self,
381        math: &mut M,
382        untransformed_gradient: &M::Vector,
383        transformed_gradient: &mut M::Vector,
384    ) {
385        math.array_mult(
386            untransformed_gradient,
387            self.diag.stds(),
388            transformed_gradient,
389        );
390
391        if let Some(inner) = &self.inner {
392            math.apply_lowrank_transform_inplace(
393                &inner.vecs,
394                &inner.vals_sqrt,
395                transformed_gradient,
396            );
397        }
398    }
399
400    /// log|det J_{F⁻¹}| = Σ log(σᵢ⁻¹) − ½ Σ log(λᵢ)
401    fn logdet(&self, _math: &mut M) -> f64 {
402        self.logdet
403    }
404}
405
#[cfg(test)]
mod tests {
    use faer::{Col, Mat};

    use crate::Math;
    use crate::math::CpuMath;
    use crate::math::test_logps::NormalLogp;

    use super::{LowRankMassMatrix, LowRankSettings};

    /// Math backend used by every test in this module.
    type TestMath = CpuMath<NormalLogp>;

    /// Problem dimension shared by all fixtures.
    const DIM: usize = 3;
    /// Absolute tolerance for the round-trip comparisons.
    const TOL: f64 = 1e-12;

    fn make_math(dim: usize) -> TestMath {
        CpuMath::new(NormalLogp::new(dim, 0.0))
    }

    /// Element-wise comparison with an absolute tolerance.
    fn assert_close(a: &[f64], b: &[f64], tol: f64) {
        assert_eq!(a.len(), b.len());
        for (i, (ai, bi)) in a.iter().zip(b.iter()).enumerate() {
            let within_tol = (ai - bi).abs() <= tol;
            assert!(within_tol, "index {i}: {ai} vs {bi} (tol {tol})");
        }
    }

    /// Copies a backend vector into a plain `Vec`.
    fn read_vec(math: &mut TestMath, v: &Col<f64>) -> Vec<f64> {
        let mut out = vec![0f64; math.dim()];
        math.write_to_slice(v, &mut out);
        out
    }

    /// Builds a length-`DIM` column from an array literal.
    fn col3(values: [f64; DIM]) -> Col<f64> {
        Col::from_fn(DIM, |i| values[i])
    }

    /// Transformation updated with zero eigenvalues, so only the diagonal scaling
    /// and mean shift have any effect.
    fn diagonal_mass(math: &mut TestMath) -> LowRankMassMatrix<TestMath> {
        let stds = col3([1.0, 2.0, 3.0]);
        let mean = col3([0.5, -1.0, 2.0]);
        let mut mass = LowRankMassMatrix::new(math, LowRankSettings::default());
        mass.update(
            math,
            stds,
            mean,
            Col::zeros(0),
            Mat::zeros(DIM, 0),
            Col::zeros(DIM),
        );
        mass
    }

    /// Unit stds plus a rank-1 correction along e_1 with eigenvalue 4.
    fn lowrank_mass(math: &mut TestMath) -> LowRankMassMatrix<TestMath> {
        let stds = col3([1.0, 1.0, 1.0]);
        let mean = col3([1.0, -0.5, 0.0]);
        let vals = faer::col![4.0f64];
        let mut vecs: Mat<f64> = Mat::zeros(DIM, 1);
        vecs[(0, 0)] = 1.0;
        let mu = col3([0.2, -0.1, 0.0]);
        let mut mass = LowRankMassMatrix::new(math, LowRankSettings::default());
        mass.update(math, stds, mean, vals, vecs, mu);
        mass
    }

    /// Sends `x` target → adapted → target and returns what comes back.
    fn target_round_trip(
        math: &mut TestMath,
        mass: &LowRankMassMatrix<TestMath>,
        x: &[f64],
    ) -> Vec<f64> {
        let mut target = math.new_array();
        let mut adapted = math.new_array();
        let mut back = math.new_array();
        math.read_from_slice(&mut target, x);

        mass.compute_transformed_position(math, &target, &mut adapted);
        mass.compute_untransformed_position(math, &adapted, &mut back);

        read_vec(math, &back)
    }

    /// Sends `z` adapted → target → adapted and returns what comes back.
    fn adapted_round_trip(
        math: &mut TestMath,
        mass: &LowRankMassMatrix<TestMath>,
        z: &[f64],
    ) -> Vec<f64> {
        let mut adapted = math.new_array();
        let mut target = math.new_array();
        let mut back = math.new_array();
        math.read_from_slice(&mut adapted, z);

        mass.compute_untransformed_position(math, &adapted, &mut target);
        mass.compute_transformed_position(math, &target, &mut back);

        read_vec(math, &back)
    }

    /// diagonal-only: target → adapted → target is the identity
    #[test]
    fn test_diagonal_round_trip() {
        let mut math = make_math(DIM);
        let mass = diagonal_mass(&mut math);

        let x_orig = [1.5f64, -0.3, 4.2];
        let recovered = target_round_trip(&mut math, &mass, &x_orig);

        assert_close(&recovered, &x_orig, TOL);
    }

    /// diagonal-only: adapted → target → adapted is the identity
    #[test]
    fn test_diagonal_round_trip_reverse() {
        let mut math = make_math(DIM);
        let mass = diagonal_mass(&mut math);

        let z_orig = [0.7f64, -1.1, 0.3];
        let recovered = adapted_round_trip(&mut math, &mass, &z_orig);

        assert_close(&recovered, &z_orig, TOL);
    }

    /// low-rank: target → adapted → target is the identity
    #[test]
    fn test_lowrank_round_trip() {
        let mut math = make_math(DIM);
        let mass = lowrank_mass(&mut math);

        let x_orig = [2.0f64, 0.5, -1.3];
        let recovered = target_round_trip(&mut math, &mass, &x_orig);

        assert_close(&recovered, &x_orig, TOL);
    }

    /// low-rank: adapted → target → adapted is the identity
    #[test]
    fn test_lowrank_round_trip_reverse() {
        let mut math = make_math(DIM);
        let mass = lowrank_mass(&mut math);

        let z_orig = [1.0f64, -0.3, 0.8];
        let recovered = adapted_round_trip(&mut math, &mass, &z_orig);

        assert_close(&recovered, &z_orig, TOL);
    }
}