Skip to main content

nuts_rs/transform/
low_rank.rs

1//! Augment the diagonal transformation with a low-rank spectral correction for correlated posteriors.
2
3use std::fmt::Debug;
4use std::iter::repeat_n;
5
6use faer::{Col, ColRef, Mat, MatRef};
7use nuts_derive::Storable;
8use serde::{Deserialize, Serialize};
9
10use crate::transform::{DiagMassMatrix, Transformation};
11use crate::{Math, sampler_stats::SamplerStats};
12
13pub fn mat_all_finite(mat: &MatRef<f64>) -> bool {
14    let mut ok = true;
15    faer::zip!(mat).for_each(|faer::unzip!(val)| ok &= val.is_finite());
16    ok
17}
18
19fn col_all_finite(mat: &ColRef<f64>) -> bool {
20    let mut ok = true;
21    faer::zip!(mat).for_each(|faer::unzip!(val)| ok &= val.is_finite());
22    ok
23}
24
25/// The low-rank correction to the affine transformation.
26///
27/// Stores U (eigenvectors), λ^{1/2} (used for F and J_F), λ^{-1/2} (used for F⁻¹), and
28/// the precomputed low-rank contribution to log|det J_{F⁻¹}|.
29struct InnerMatrix<M: Math> {
30    vecs: M::EigVectors,
31    /// λ^{1/2} — used for the forward map F and its Jacobian J_F
32    vals_sqrt: M::EigValues,
33    /// λ^{-1/2} — used for the inverse position transform F⁻¹
34    vals_sqrt_inv: M::EigValues,
35    /// -½ Σ log(λᵢ) — low-rank contribution to log|det J_{F⁻¹}|, precomputed
36    /// so we never need to pull eigenvalues back from a device (e.g. GPU).
37    logdet_contribution: f64,
38    num_eigenvalues: u64,
39}
40
41impl<M: Math> Debug for InnerMatrix<M> {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        f.debug_struct("InnerMatrix")
44            .field("vecs", &"<eig vectors>")
45            .field("vals_sqrt", &"<sqrt eig values>")
46            .field("vals_sqrt_inv", &"<inv sqrt eig values>")
47            .field("logdet_contribution", &self.logdet_contribution)
48            .field("num_eigenvalues", &self.num_eigenvalues)
49            .finish()
50    }
51}
52
53impl<M: Math> InnerMatrix<M> {
54    fn new(math: &mut M, mut vals: Col<f64>, vecs: Mat<f64>) -> Self {
55        // Precompute -½ Σ log(λᵢ) while vals still holds the raw eigenvalues.
56        let logdet_contribution: f64 = vals.iter().map(|&v| -0.5 * v.ln()).sum();
57        let num_eigenvalues = vals.nrows() as u64;
58
59        let vecs = math.new_eig_vectors(
60            vecs.col_iter()
61                .map(|col| col.try_as_col_major().unwrap().as_slice()),
62        );
63
64        // λ^{1/2} — needed for the forward map F and its Jacobian J_F
65        vals.iter_mut().for_each(|x| *x = x.sqrt());
66        let vals_sqrt = math.new_eig_values(vals.try_as_col_major().unwrap().as_slice());
67
68        // λ^{-1/2} — needed for the inverse position transform F⁻¹
69        vals.iter_mut().for_each(|x| *x = x.recip());
70        let vals_sqrt_inv = math.new_eig_values(vals.try_as_col_major().unwrap().as_slice());
71
72        Self {
73            vecs,
74            vals_sqrt,
75            vals_sqrt_inv,
76            logdet_contribution,
77            num_eigenvalues,
78        }
79    }
80
81    fn logdet(&self) -> f64 {
82        self.logdet_contribution
83    }
84}
85
86/// Low-rank + diagonal affine transformation.
87///
88/// The full forward map (adapted → target) is
89///
90///   F(y) = σ ⊙ (I + U (diag(λ)^{1/2} − I) Uᵀ) y + μ
91///
92/// so the inverse (target → adapted) is
93///
94///   F⁻¹(x) = (I + U (diag(λ)^{-1/2} − I) Uᵀ) ((x − μ) ⊙ σ⁻¹)
95///
96/// The Jacobian of F is  J_F = diag(σ) (I + U (diag(λ)^{1/2} − I) Uᵀ),
97/// so  log|det J_{F⁻¹}| = Σ log(σᵢ⁻¹) − ½ Σ log(λᵢ).
98///
99/// In the adapted space the mass matrix is the identity; leapfrog steps
100/// operate entirely in that space.  When no eigenvectors are available
101/// (early adaptation) the transform falls back to the pure diagonal case.
102pub struct LowRankMassMatrix<M: Math> {
103    diag: DiagMassMatrix<M>,
104    inner: Option<InnerMatrix<M>>,
105    settings: LowRankSettings,
106    logdet: f64,
107    /// Monotonically increasing id; bumped whenever the matrix changes.
108    id: i64,
109}
110
111impl<M: Math> Debug for LowRankMassMatrix<M> {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        f.debug_struct("LowRankMassMatrix")
114            .field("diag", &self.diag)
115            .field("inner", &self.inner)
116            .field("settings", &self.settings)
117            .field("id", &self.id)
118            .finish()
119    }
120}
121
122impl<M: Math> LowRankMassMatrix<M> {
123    pub fn new(math: &mut M, settings: LowRankSettings) -> Self {
124        Self {
125            diag: DiagMassMatrix::new(math, settings.store_mass_matrix),
126            settings,
127            logdet: 0f64,
128            inner: None,
129            id: -1,
130        }
131    }
132
133    /// Initialise from the gradient at the first draw only (no mean / covariance information yet).
134    pub fn update_from_grad(
135        &mut self,
136        math: &mut M,
137        pos: &M::Vector,
138        grad: &M::Vector,
139        fill_invalid: f64,
140        clamp: (f64, f64),
141    ) {
142        self.inner = None;
143        self.diag
144            .update_diag_grad(math, pos, grad, fill_invalid, clamp);
145        self.logdet = self.diag.logdet();
146        self.id += 1;
147    }
148
149    /// Full update from a window of draws and scores.
150    ///
151    /// * `stds`  — diagonal scales σ
152    /// * `mean`  — optimal translation μ* = x̄ + σ² ⊙ ᾱ (in target space)
153    /// * `vals`  — filtered eigenvalues λ of the SPD geometric mean
154    /// * `vecs`  — corresponding eigenvectors U (columns, back-projected to ℝᵈ)
155    pub fn update(
156        &mut self,
157        math: &mut M,
158        stds: Col<f64>,
159        mean: Col<f64>,
160        vals: Col<f64>,
161        vecs: Mat<f64>,
162    ) {
163        if (!col_all_finite(&stds.as_ref())) | (!col_all_finite(&mean.as_ref())) {
164            return;
165        }
166        if (!col_all_finite(&vals.as_ref())) | (!mat_all_finite(&vecs.as_ref())) {
167            return;
168        }
169
170        let mut stds_array = math.new_array();
171        math.read_from_slice(&mut stds_array, stds.try_as_col_major().unwrap().as_slice());
172        let mut mean_array = math.new_array();
173        math.read_from_slice(&mut mean_array, mean.try_as_col_major().unwrap().as_slice());
174        self.diag.set_transform(math, &stds_array, &mean_array);
175
176        let inner = InnerMatrix::new(math, vals, vecs);
177        self.logdet = inner.logdet() + self.diag.logdet();
178        self.inner = Some(inner);
179        self.id += 1;
180    }
181}
182
183#[derive(Clone, Debug, Copy, Serialize, Deserialize)]
184pub struct LowRankSettings {
185    pub store_mass_matrix: bool,
186    pub gamma: f64,
187    pub eigval_cutoff: f64,
188}
189
190impl Default for LowRankSettings {
191    fn default() -> Self {
192        Self {
193            store_mass_matrix: false,
194            gamma: 1e-5,
195            eigval_cutoff: 2f64,
196        }
197    }
198}
199
200#[derive(Debug, Storable)]
201pub struct MatrixStats {
202    /// The transformation version counter at the time of this update.
203    /// `Some` only on draws where the transformation changed.
204    #[storable(event = "transformation_update")]
205    pub transformation_update_id: Option<i64>,
206    #[storable(event = "transformation_update", dims("unconstrained_parameter"))]
207    pub mass_matrix_eigvals: Option<Vec<f64>>,
208    #[storable(event = "transformation_update", dims("unconstrained_parameter"))]
209    pub mass_matrix_stds: Option<Vec<f64>>,
210    #[storable(event = "transformation_update")]
211    pub num_eigenvalues: Option<u64>,
212}
213
214impl<M: Math> SamplerStats<M> for LowRankMassMatrix<M> {
215    type Stats = MatrixStats;
216    type StatsOptions = i64;
217
218    fn extract_stats(&self, math: &mut M, last_id: Self::StatsOptions) -> Self::Stats {
219        if self.id != last_id {
220            let num_eigenvalues = Some(
221                self.inner
222                    .as_ref()
223                    .map(|inner| inner.num_eigenvalues)
224                    .unwrap_or(0),
225            );
226            if self.settings.store_mass_matrix {
227                let stds = Some(math.box_array(self.diag.stds()));
228                let eigvals = self
229                    .inner
230                    .as_ref()
231                    .map(|inner| math.eigs_as_array(&inner.vals_sqrt));
232                let mut eigvals = eigvals.map(|x| x.into_vec());
233                if let Some(ref mut eigvals) = eigvals {
234                    eigvals.extend(repeat_n(
235                        f64::NAN,
236                        stds.as_ref().unwrap().len() - eigvals.len(),
237                    ));
238                }
239                MatrixStats {
240                    transformation_update_id: Some(self.id),
241                    mass_matrix_eigvals: eigvals,
242                    mass_matrix_stds: stds.map(|x| x.into_vec()),
243                    num_eigenvalues,
244                }
245            } else {
246                MatrixStats {
247                    transformation_update_id: Some(self.id),
248                    mass_matrix_eigvals: None,
249                    mass_matrix_stds: None,
250                    num_eigenvalues,
251                }
252            }
253        } else {
254            MatrixStats {
255                transformation_update_id: None,
256                mass_matrix_eigvals: None,
257                mass_matrix_stds: None,
258                num_eigenvalues: None,
259            }
260        }
261    }
262}
263
264impl<M: Math> Transformation<M> for LowRankMassMatrix<M> {
265    fn init_from_untransformed_position(
266        &self,
267        math: &mut M,
268        untransformed_position: &M::Vector,
269        untransformed_gradient: &mut M::Vector,
270        transformed_position: &mut M::Vector,
271        transformed_gradient: &mut M::Vector,
272    ) -> Result<(f64, f64), M::LogpErr> {
273        let logp = math.logp_array(untransformed_position, untransformed_gradient)?;
274        self.compute_transformed_position(math, untransformed_position, transformed_position);
275        self.compute_transformed_gradient(math, untransformed_gradient, transformed_gradient);
276        Ok((logp, self.logdet(math)))
277    }
278
279    fn init_from_transformed_position(
280        &self,
281        math: &mut M,
282        untransformed_position: &mut M::Vector,
283        untransformed_gradient: &mut M::Vector,
284        transformed_position: &M::Vector,
285        transformed_gradient: &mut M::Vector,
286    ) -> Result<(f64, f64), M::LogpErr> {
287        self.compute_untransformed_position(math, transformed_position, untransformed_position);
288        let logp = math.logp_array(untransformed_position, untransformed_gradient)?;
289        self.compute_transformed_gradient(math, untransformed_gradient, transformed_gradient);
290        Ok((logp, self.logdet(math)))
291    }
292
293    fn inv_transform_normalize(
294        &self,
295        math: &mut M,
296        untransformed_position: &M::Vector,
297        untransformed_gradient: &M::Vector,
298        transformed_position: &mut M::Vector,
299        transformed_gradient: &mut M::Vector,
300    ) -> Result<f64, M::LogpErr> {
301        self.compute_transformed_position(math, untransformed_position, transformed_position);
302        self.compute_transformed_gradient(math, untransformed_gradient, transformed_gradient);
303        Ok(self.logdet(math))
304    }
305
306    fn transformation_id(&self, _math: &mut M) -> i64 {
307        self.id
308    }
309
310    fn next_stats_options(&self, _math: &mut M, _current: i64) -> i64 {
311        self.id
312    }
313}
314
315impl<M: Math> LowRankMassMatrix<M> {
316    fn compute_transformed_position(
317        &self,
318        math: &mut M,
319        untransformed_position: &M::Vector,
320        transformed_position: &mut M::Vector,
321    ) {
322        math.axpy_out(
323            &self.diag.mean(),
324            &untransformed_position,
325            -1.0,
326            transformed_position,
327        );
328        math.array_mult_inplace(transformed_position, self.diag.inv_stds());
329
330        if let Some(inner) = &self.inner {
331            math.apply_lowrank_transform_inplace(
332                &inner.vecs,
333                &inner.vals_sqrt_inv,
334                transformed_position,
335            );
336        }
337    }
338
339    fn compute_untransformed_position(
340        &self,
341        math: &mut M,
342        transformed_position: &M::Vector,
343        untransformed_position: &mut M::Vector,
344    ) {
345        match &self.inner {
346            None => {
347                math.array_mult(
348                    transformed_position,
349                    &self.diag.stds(),
350                    untransformed_position,
351                );
352            }
353            Some(inner) => {
354                math.apply_lowrank_transform(
355                    &inner.vecs,
356                    &inner.vals_sqrt,
357                    transformed_position,
358                    untransformed_position,
359                );
360                math.array_mult_inplace(untransformed_position, &self.diag.stds());
361            }
362        }
363        math.axpy(&self.diag.mean(), untransformed_position, 1.0);
364    }
365
366    fn compute_transformed_gradient(
367        &self,
368        math: &mut M,
369        untransformed_gradient: &M::Vector,
370        transformed_gradient: &mut M::Vector,
371    ) {
372        math.array_mult(
373            untransformed_gradient,
374            self.diag.stds(),
375            transformed_gradient,
376        );
377
378        if let Some(inner) = &self.inner {
379            math.apply_lowrank_transform_inplace(
380                &inner.vecs,
381                &inner.vals_sqrt,
382                transformed_gradient,
383            );
384        }
385    }
386
387    /// log|det J_{F⁻¹}| = Σ log(σᵢ⁻¹) − ½ Σ log(λᵢ)
388    fn logdet(&self, _math: &mut M) -> f64 {
389        self.logdet
390    }
391}