nuts_rs/mass_matrix/
low_rank.rs

1use std::collections::VecDeque;
2use std::iter::repeat;
3
4use faer::{Col, ColRef, Mat, MatRef, Scale};
5use itertools::Itertools;
6use nuts_derive::Storable;
7use serde::Serialize;
8
9use super::adapt::MassMatrixAdaptStrategy;
10use super::diagonal::{DrawGradCollector, MassMatrix};
11use crate::{
12    Math, NutsError, euclidean_hamiltonian::EuclideanPoint, hamiltonian::Point,
13    sampler_stats::SamplerStats,
14};
15
16fn mat_all_finite(mat: &MatRef<f64>) -> bool {
17    let mut ok = true;
18    faer::zip!(mat).for_each(|faer::unzip!(val)| ok &= val.is_finite());
19    ok
20}
21
22fn col_all_finite(mat: &ColRef<f64>) -> bool {
23    let mut ok = true;
24    faer::zip!(mat).for_each(|faer::unzip!(val)| ok &= val.is_finite());
25    ok
26}
27
/// Eigen-decomposition backing the low-rank part of a [`LowRankMassMatrix`],
/// stored in the backend types provided by the [`Math`] implementation.
#[derive(Debug)]
struct InnerMatrix<M: Math> {
    /// Eigenvectors, built from the columns of the matrix handed to `InnerMatrix::new`.
    vecs: M::EigVectors,
    /// Eigenvalues exactly as handed to `InnerMatrix::new`.
    vals: M::EigValues,
    /// `1 / sqrt(val)` for every entry of `vals`.
    vals_sqrt_inv: M::EigValues,
    /// Number of stored eigenvalues.
    num_eigenvalues: u64,
}
35
36impl<M: Math> InnerMatrix<M> {
37    fn new(math: &mut M, mut vals: Col<f64>, vecs: Mat<f64>) -> Self {
38        let vecs = math.new_eig_vectors(
39            vecs.col_iter()
40                .map(|col| col.try_as_col_major().unwrap().as_slice()),
41        );
42        let vals_math = math.new_eig_values(vals.try_as_col_major().unwrap().as_slice());
43
44        vals.iter_mut().for_each(|x| *x = x.sqrt().recip());
45        let vals_inv_math = math.new_eig_values(vals.try_as_col_major().unwrap().as_slice());
46
47        Self {
48            vecs,
49            vals: vals_math,
50            vals_sqrt_inv: vals_inv_math,
51            num_eigenvalues: vals.nrows() as u64,
52        }
53    }
54}
55
/// Mass matrix made of a per-dimension scaling plus an optional low-rank
/// eigen-decomposition.
#[derive(Debug)]
pub struct LowRankMassMatrix<M: Math> {
    /// Per-dimension variance; `update` sets it to the square of `stds`.
    variance: M::Vector,
    /// Per-dimension standard deviations.
    stds: M::Vector,
    /// Element-wise reciprocal of `stds`.
    inv_stds: M::Vector,
    /// Low-rank part. `None` after construction and whenever `update`
    /// receives eigenvalues or eigenvectors that are not all finite.
    inner: Option<InnerMatrix<M>>,
    /// Settings this matrix was created with.
    settings: LowRankSettings,
}
64
65impl<M: Math> LowRankMassMatrix<M> {
66    pub fn new(math: &mut M, settings: LowRankSettings) -> Self {
67        Self {
68            variance: math.new_array(),
69            inv_stds: math.new_array(),
70            stds: math.new_array(),
71            settings,
72            inner: None,
73        }
74    }
75
76    fn update_from_grad(
77        &mut self,
78        math: &mut M,
79        grad: &<M as Math>::Vector,
80        fill_invalid: f64,
81        clamp: (f64, f64),
82    ) {
83        math.array_update_var_inv_std_grad(
84            &mut self.variance,
85            &mut self.inv_stds,
86            grad,
87            fill_invalid,
88            clamp,
89        );
90        let mut vals = vec![0f64; math.dim()];
91        math.write_to_slice(&self.inv_stds, &mut vals);
92        vals.iter_mut().for_each(|x| *x = x.recip());
93        math.read_from_slice(&mut self.stds, &vals);
94    }
95
96    fn update(&mut self, math: &mut M, mut stds: Col<f64>, vals: Col<f64>, vecs: Mat<f64>) {
97        math.read_from_slice(&mut self.stds, stds.try_as_col_major().unwrap().as_slice());
98
99        stds.iter_mut().for_each(|x| *x = x.recip());
100        math.read_from_slice(
101            &mut self.inv_stds,
102            stds.try_as_col_major().unwrap().as_slice(),
103        );
104
105        stds.iter_mut().for_each(|x| *x = x.recip() * x.recip());
106        math.read_from_slice(
107            &mut self.variance,
108            stds.try_as_col_major().unwrap().as_slice(),
109        );
110
111        if col_all_finite(&vals.as_ref()) & mat_all_finite(&vecs.as_ref()) {
112            self.inner = Some(InnerMatrix::new(math, vals, vecs));
113        } else {
114            self.inner = None;
115        }
116    }
117}
118
/// Settings for the low-rank mass matrix adaptation.
#[derive(Clone, Debug, Copy, Serialize)]
pub struct LowRankSettings {
    /// Whether `extract_stats` on the mass matrix includes the standard
    /// deviations and eigenvalues in the reported statistics.
    pub store_mass_matrix: bool,
    /// Regularization parameter handed to `estimate_mass_matrix`, which scales
    /// both Gram matrices by `1 / gamma` before adding the identity.
    pub gamma: f64,
    /// Eigenvalues are kept only if they are larger than this value or
    /// smaller than its reciprocal.
    pub eigval_cutoff: f64,
}
125
126impl Default for LowRankSettings {
127    fn default() -> Self {
128        Self {
129            store_mass_matrix: false,
130            gamma: 1e-5,
131            eigval_cutoff: 2f64,
132        }
133    }
134}
135
// Statistics reported by `LowRankMassMatrix::extract_stats`.
#[derive(Debug, Storable)]
pub struct MatrixStats {
    // Eigenvalues of the low-rank part, padded with NaN up to the length of
    // `mass_matrix_stds`. `None` when storing the mass matrix is disabled or
    // no low-rank part exists.
    #[storable(dims("unconstrained_parameter"))]
    pub mass_matrix_eigvals: Option<Vec<f64>>,
    // Per-dimension standard deviations. `None` when storing the mass matrix
    // is disabled.
    #[storable(dims("unconstrained_parameter"))]
    pub mass_matrix_stds: Option<Vec<f64>>,
    // Number of eigenvalues in the low-rank part; 0 when there is none.
    pub num_eigenvalues: u64,
}
144
145impl<M: Math> SamplerStats<M> for LowRankMassMatrix<M> {
146    type Stats = MatrixStats;
147    type StatsOptions = ();
148
149    fn extract_stats(&self, math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
150        if self.settings.store_mass_matrix {
151            let stds = Some(math.box_array(&self.stds));
152            let eigvals = self
153                .inner
154                .as_ref()
155                .map(|inner| math.eigs_as_array(&inner.vals));
156            let mut eigvals = eigvals.map(|x| x.into_vec());
157            if let Some(ref mut eigvals) = eigvals {
158                eigvals.extend(repeat(f64::NAN).take(stds.as_ref().unwrap().len() - eigvals.len()));
159            }
160            MatrixStats {
161                mass_matrix_eigvals: eigvals,
162                mass_matrix_stds: stds.map(|x| x.into_vec()),
163                num_eigenvalues: self
164                    .inner
165                    .as_ref()
166                    .map(|inner| inner.num_eigenvalues)
167                    .unwrap_or(0),
168            }
169        } else {
170            MatrixStats {
171                mass_matrix_eigvals: None,
172                mass_matrix_stds: None,
173                num_eigenvalues: self
174                    .inner
175                    .as_ref()
176                    .map(|inner| inner.num_eigenvalues)
177                    .unwrap_or(0),
178            }
179        }
180    }
181}
182
183impl<M: Math> MassMatrix<M> for LowRankMassMatrix<M> {
184    fn update_velocity(&self, math: &mut M, state: &mut EuclideanPoint<M>) {
185        let Some(inner) = self.inner.as_ref() else {
186            math.array_mult(&self.variance, &state.momentum, &mut state.velocity);
187            return;
188        };
189
190        math.array_mult_eigs(
191            &self.stds,
192            &state.momentum,
193            &mut state.velocity,
194            &inner.vecs,
195            &inner.vals,
196        );
197    }
198
199    fn update_kinetic_energy(&self, math: &mut M, state: &mut EuclideanPoint<M>) {
200        state.kinetic_energy = 0.5 * math.array_vector_dot(&state.momentum, &state.velocity);
201    }
202
203    fn randomize_momentum<R: rand::Rng + ?Sized>(
204        &self,
205        math: &mut M,
206        state: &mut EuclideanPoint<M>,
207        rng: &mut R,
208    ) {
209        let Some(inner) = self.inner.as_ref() else {
210            math.array_gaussian(rng, &mut state.momentum, &self.inv_stds);
211            return;
212        };
213
214        math.array_gaussian_eigs(
215            rng,
216            &mut state.momentum,
217            &self.inv_stds,
218            &inner.vals_sqrt_inv,
219            &inner.vecs,
220        );
221    }
222}
223
224/*
225#[derive(Debug, Clone)]
226pub struct Stats {
227    foreground_length: u64,
228    background_length: u64,
229    is_update: bool,
230    diag: Box<[f64]>,
231    eigvalues: Box<[f64]>,
232    eigvectors: Box<[f64]>,
233}
234*/
235
/// Adaptation state for [`LowRankMassMatrix`]: a window of stored positions
/// and gradients from which a new mass matrix is estimated.
#[derive(Debug)]
pub struct LowRankMassMatrixStrategy {
    /// Stored positions, oldest first; each entry has `ndim` elements.
    draws: VecDeque<Vec<f64>>,
    /// Stored gradients, kept in lockstep with `draws`.
    grads: VecDeque<Vec<f64>>,
    /// Dimension of the parameter space.
    ndim: usize,
    /// Number of points that were stored at the most recent `switch`; that
    /// many points are dropped from the front at the next `switch`.
    background_split: usize,
    /// Settings used when estimating the mass matrix.
    settings: LowRankSettings,
}
244
245impl LowRankMassMatrixStrategy {
246    pub fn new(ndim: usize, settings: LowRankSettings) -> Self {
247        let draws = VecDeque::with_capacity(100);
248        let grads = VecDeque::with_capacity(100);
249
250        Self {
251            draws,
252            grads,
253            ndim,
254            background_split: 0,
255            settings,
256        }
257    }
258
259    pub fn add_draw<M: Math>(&mut self, math: &mut M, point: &impl Point<M>) {
260        assert!(math.dim() == self.ndim);
261        let mut draw = vec![0f64; self.ndim];
262        math.write_to_slice(point.position(), &mut draw);
263        let mut grad = vec![0f64; self.ndim];
264        math.write_to_slice(point.gradient(), &mut grad);
265
266        self.draws.push_back(draw);
267        self.grads.push_back(grad);
268    }
269
270    pub fn clear(&mut self) {
271        self.draws.clear();
272        self.grads.clear();
273    }
274
275    pub fn update<M: Math>(&self, math: &mut M, matrix: &mut LowRankMassMatrix<M>) {
276        let draws_vec = &self.draws;
277        let grads_vec = &self.grads;
278
279        let ndraws = draws_vec.len();
280        assert!(grads_vec.len() == ndraws);
281
282        let mut draws: Mat<f64> = Mat::zeros(self.ndim, ndraws);
283        let mut grads: Mat<f64> = Mat::zeros(self.ndim, ndraws);
284
285        for (i, (draw, grad)) in draws_vec.iter().zip(grads_vec.iter()).enumerate() {
286            draws.col_as_slice_mut(i).copy_from_slice(&draw[..]);
287            grads.col_as_slice_mut(i).copy_from_slice(&grad[..]);
288        }
289
290        let Some((stds, vals, vecs)) = self.compute_update(draws, grads) else {
291            return;
292        };
293
294        matrix.update(math, stds, vals, vecs);
295    }
296
297    fn compute_update(
298        &self,
299        mut draws: Mat<f64>,
300        mut grads: Mat<f64>,
301    ) -> Option<(Col<f64>, Col<f64>, Mat<f64>)> {
302        let stds = rescale_points(&mut draws, &mut grads);
303
304        let svd_draws = draws.thin_svd().ok()?;
305        let svd_grads = grads.thin_svd().ok()?;
306
307        let subspace = faer::concat![[svd_draws.U(), svd_grads.U()]];
308
309        let subspace_qr = subspace.col_piv_qr();
310
311        let subspace_basis = subspace_qr.compute_thin_Q();
312
313        let draws_proj = subspace_basis.transpose() * (&draws);
314        let grads_proj = subspace_basis.transpose() * (&grads);
315
316        let (vals, vecs) = estimate_mass_matrix(draws_proj, grads_proj, self.settings.gamma)?;
317
318        let filtered = vals
319            .iter()
320            .zip(vecs.col_iter())
321            .filter(|&(&val, _)| {
322                (val > self.settings.eigval_cutoff) | (val < self.settings.eigval_cutoff.recip())
323            })
324            .collect_vec();
325
326        let vals = filtered.iter().map(|x| *x.0).collect_vec();
327        let vals = ColRef::from_slice(&vals).to_owned();
328
329        let vecs_vec = filtered.into_iter().map(|x| x.1).collect_vec();
330        let mut vecs = Mat::zeros(subspace_basis.ncols(), vals.nrows());
331        vecs.col_iter_mut()
332            .zip(vecs_vec.iter())
333            .for_each(|(mut col, vals)| col.copy_from(vals));
334
335        let vecs = subspace_basis * vecs;
336        Some((stds, vals, vecs))
337    }
338}
339
340fn rescale_points(draws: &mut Mat<f64>, grads: &mut Mat<f64>) -> Col<f64> {
341    let (ndim, ndraws) = draws.shape();
342
343    Col::from_fn(ndim, |col| {
344        let draw_mean = draws.row(col).sum() / (ndraws as f64);
345        let grad_mean = grads.row(col).sum() / (ndraws as f64);
346        let draw_std: f64 = draws
347            .row(col)
348            .iter()
349            .map(|&val| (val - draw_mean) * (val - draw_mean))
350            .sum::<f64>()
351            .sqrt();
352        let grad_std: f64 = grads
353            .row(col)
354            .iter()
355            .map(|&val| (val - grad_mean) * (val - grad_mean))
356            .sum::<f64>()
357            .sqrt();
358
359        let std = (draw_std / grad_std).sqrt();
360
361        let draw_scale = (std * (ndraws as f64)).recip();
362        draws
363            .row_mut(col)
364            .iter_mut()
365            .for_each(|val| *val = (*val - draw_mean) * draw_scale);
366
367        let grad_scale = std * (ndraws as f64).recip();
368        grads
369            .row_mut(col)
370            .iter_mut()
371            .for_each(|val| *val = (*val - grad_mean) * grad_scale);
372
373        std
374    })
375}
376
377fn estimate_mass_matrix(
378    draws: Mat<f64>,
379    grads: Mat<f64>,
380    gamma: f64,
381) -> Option<(Col<f64>, Mat<f64>)> {
382    let mut cov_draws = (&draws) * draws.transpose();
383    let mut cov_grads = (&grads) * grads.transpose();
384
385    cov_draws *= Scale(gamma.recip());
386    cov_grads *= Scale(gamma.recip());
387
388    cov_draws
389        .diagonal_mut()
390        .column_vector_mut()
391        .iter_mut()
392        .for_each(|x| *x += 1f64);
393
394    cov_grads
395        .diagonal_mut()
396        .column_vector_mut()
397        .iter_mut()
398        .for_each(|x| *x += 1f64);
399
400    let mean = spd_mean(cov_draws, cov_grads)?;
401
402    let mean_eig = mean.self_adjoint_eigen(faer::Side::Lower).ok()?;
403
404    Some((
405        mean_eig.S().column_vector().to_owned(),
406        mean_eig.U().to_owned(),
407    ))
408}
409
410fn spd_mean(cov_draws: Mat<f64>, cov_grads: Mat<f64>) -> Option<Mat<f64>> {
411    let eigs_grads = cov_grads.self_adjoint_eigen(faer::Side::Lower).ok()?;
412
413    let u = eigs_grads.U();
414    let eigs = eigs_grads.S().column_vector().to_owned();
415
416    let mut eigs_sqrt = eigs.clone();
417    eigs_sqrt.iter_mut().for_each(|val| *val = val.sqrt());
418    let cov_grads_sqrt = u * eigs_sqrt.into_diagonal() * u.transpose();
419    let m = (&cov_grads_sqrt) * cov_draws * cov_grads_sqrt;
420
421    let m_eig = m.self_adjoint_eigen(faer::Side::Lower).ok()?;
422
423    let m_u = m_eig.U();
424    let mut m_s = m_eig.S().column_vector().to_owned();
425    m_s.iter_mut().for_each(|val| *val = val.sqrt());
426
427    let m_sqrt = m_u * m_s.into_diagonal() * m_u.transpose();
428
429    let mut eigs_grads_inv = eigs;
430    eigs_grads_inv
431        .iter_mut()
432        .for_each(|val| *val = val.sqrt().recip());
433    let grads_inv_sqrt = u * eigs_grads_inv.into_diagonal() * u.transpose();
434
435    Some((&grads_inv_sqrt) * m_sqrt * grads_inv_sqrt)
436}
437
/// The adaptation strategy itself reports no statistics.
impl<M: Math> SamplerStats<M> for LowRankMassMatrixStrategy {
    type Stats = ();
    type StatsOptions = ();

    /// Returns the unit value; there is nothing to extract.
    fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {}
}
444
445impl<M: Math> MassMatrixAdaptStrategy<M> for LowRankMassMatrixStrategy {
446    type MassMatrix = LowRankMassMatrix<M>;
447    type Collector = DrawGradCollector<M>;
448    type Options = LowRankSettings;
449
450    fn new(math: &mut M, options: Self::Options, _num_tune: u64, _chain: u64) -> Self {
451        Self::new(math.dim(), options)
452    }
453
454    fn init<R: rand::Rng + ?Sized>(
455        &mut self,
456        math: &mut M,
457        _options: &mut crate::nuts::NutsOptions,
458        mass_matrix: &mut Self::MassMatrix,
459        point: &impl Point<M>,
460        _rng: &mut R,
461    ) -> Result<(), NutsError> {
462        self.add_draw(math, point);
463        mass_matrix.update_from_grad(math, point.gradient(), 1f64, (1e-20, 1e20));
464        Ok(())
465    }
466
467    fn new_collector(&self, math: &mut M) -> Self::Collector {
468        DrawGradCollector::new(math)
469    }
470
471    fn update_estimators(&mut self, math: &mut M, collector: &Self::Collector) {
472        if collector.is_good {
473            let mut draw = vec![0f64; self.ndim];
474            math.write_to_slice(&collector.draw, &mut draw);
475            self.draws.push_back(draw);
476
477            let mut grad = vec![0f64; self.ndim];
478            math.write_to_slice(&collector.grad, &mut grad);
479            self.grads.push_back(grad);
480        }
481    }
482
483    fn switch(&mut self, _math: &mut M) {
484        for _ in 0..self.background_split {
485            self.draws.pop_front().expect("Could not drop draw");
486            self.grads.pop_front().expect("Could not drop gradient");
487        }
488        self.background_split = self.draws.len();
489        assert!(self.draws.len() == self.grads.len());
490    }
491
492    fn current_count(&self) -> u64 {
493        self.draws.len() as u64
494    }
495
496    fn background_count(&self) -> u64 {
497        self.draws.len().checked_sub(self.background_split).unwrap() as u64
498    }
499
500    fn adapt(&self, math: &mut M, mass_matrix: &mut Self::MassMatrix) -> bool {
501        if <LowRankMassMatrixStrategy as MassMatrixAdaptStrategy<M>>::current_count(self) < 3 {
502            return false;
503        }
504        self.update(math, mass_matrix);
505
506        true
507    }
508}
509
#[cfg(test)]
mod test {
    use std::ops::AddAssign;

    use equator::Cmp;
    use faer::{Col, Mat, utils::approx::ApproxEq};
    use rand::{Rng, SeedableRng, rngs::SmallRng};
    use rand_distr::StandardNormal;

    use super::{estimate_mass_matrix, mat_all_finite, spd_mean};

    /// Builds a square matrix with `diag` on the diagonal and zeros elsewhere.
    fn diagonal_matrix(diag: Col<f64>) -> Mat<f64> {
        let size = diag.nrows();
        let mut out: Mat<f64> = Mat::zeros(size, size);
        out.diagonal_mut().column_vector_mut().add_assign(diag);
        out
    }

    /// For diagonal inputs the result is diagonal with entries `sqrt(x / y)`.
    #[test]
    fn test_spd_mean() {
        let x = diagonal_matrix(faer::col![1., 4., 8.]);
        let y = diagonal_matrix(faer::col![1., 1., 0.5]);
        let expected = diagonal_matrix(faer::col![1., 2., 4.]);

        let out = spd_mean(x, y).expect("Failed to compute spd mean");

        let comp = ApproxEq {
            abs_tol: 1e-10,
            rel_tol: 1e-10,
        };

        faer::zip!(&out, &expected).for_each(|faer::unzip!(out, expected)| {
            comp.test(out, expected).unwrap();
        });
    }

    /// With gradients equal to the negated draws every eigenvalue is expected
    /// to be one.
    #[test]
    fn test_estimate_mass_matrix() {
        let distr = StandardNormal;
        let mut rng = SmallRng::seed_from_u64(1);

        let draws: Mat<f64> = Mat::from_fn(20, 3, |_, _| rng.sample(distr));
        let grads = -(&draws);

        let (vals, vecs) =
            estimate_mass_matrix(draws, grads, 0.0001).expect("Failed to compute mass matrix");

        assert!(vals.iter().cloned().all(|x| x > 0.));
        assert!(mat_all_finite(&vecs.as_ref()));

        let comp = ApproxEq {
            abs_tol: 1e-5,
            rel_tol: 1e-5,
        };
        let expected = Col::full(20, 1.);

        faer::zip!(&vals, &expected).for_each(|faer::unzip!(out, expected)| {
            comp.test(out, expected).unwrap();
        });
    }
}
576}