1use std::collections::VecDeque;
2use std::iter::repeat;
3
4use faer::{Col, ColRef, Mat, MatRef, Scale};
5use itertools::Itertools;
6use nuts_derive::Storable;
7use serde::Serialize;
8
9use super::adapt::MassMatrixAdaptStrategy;
10use super::diagonal::{DrawGradCollector, MassMatrix};
11use crate::{
12 Math, NutsError, euclidean_hamiltonian::EuclideanPoint, hamiltonian::Point,
13 sampler_stats::SamplerStats,
14};
15
16fn mat_all_finite(mat: &MatRef<f64>) -> bool {
17 let mut ok = true;
18 faer::zip!(mat).for_each(|faer::unzip!(val)| ok &= val.is_finite());
19 ok
20}
21
22fn col_all_finite(mat: &ColRef<f64>) -> bool {
23 let mut ok = true;
24 faer::zip!(mat).for_each(|faer::unzip!(val)| ok &= val.is_finite());
25 ok
26}
27
/// Low-rank part of the mass matrix: the eigenpairs that survived the eigenvalue
/// cutoff, stored in the math backend's representation.
#[derive(Debug)]
struct InnerMatrix<M: Math> {
    /// Eigenvectors, built from the columns of the faer eigenvector matrix.
    vecs: M::EigVectors,
    /// Eigenvalues matching `vecs`.
    vals: M::EigValues,
    /// `1 / sqrt(val)` for every eigenvalue; used when sampling momenta.
    vals_sqrt_inv: M::EigValues,
    /// Number of retained eigenvalues (reported in the sampler stats).
    num_eigenvalues: u64,
}
35
36impl<M: Math> InnerMatrix<M> {
37 fn new(math: &mut M, mut vals: Col<f64>, vecs: Mat<f64>) -> Self {
38 let vecs = math.new_eig_vectors(
39 vecs.col_iter()
40 .map(|col| col.try_as_col_major().unwrap().as_slice()),
41 );
42 let vals_math = math.new_eig_values(vals.try_as_col_major().unwrap().as_slice());
43
44 vals.iter_mut().for_each(|x| *x = x.sqrt().recip());
45 let vals_inv_math = math.new_eig_values(vals.try_as_col_major().unwrap().as_slice());
46
47 Self {
48 vecs,
49 vals: vals_math,
50 vals_sqrt_inv: vals_inv_math,
51 num_eigenvalues: vals.nrows() as u64,
52 }
53 }
54}
55
/// Mass matrix made of a per-parameter diagonal scaling plus an optional low-rank
/// eigen-correction.
#[derive(Debug)]
pub struct LowRankMassMatrix<M: Math> {
    /// Diagonal variance; `update` sets it to the square of `stds`.
    variance: M::Vector,
    /// Per-parameter standard deviations.
    stds: M::Vector,
    /// Element-wise reciprocal of `stds`.
    inv_stds: M::Vector,
    /// Eigen-correction; `None` means only the diagonal part is used.
    inner: Option<InnerMatrix<M>>,
    /// Estimation and reporting settings.
    settings: LowRankSettings,
}
64
65impl<M: Math> LowRankMassMatrix<M> {
66 pub fn new(math: &mut M, settings: LowRankSettings) -> Self {
67 Self {
68 variance: math.new_array(),
69 inv_stds: math.new_array(),
70 stds: math.new_array(),
71 settings,
72 inner: None,
73 }
74 }
75
76 fn update_from_grad(
77 &mut self,
78 math: &mut M,
79 grad: &<M as Math>::Vector,
80 fill_invalid: f64,
81 clamp: (f64, f64),
82 ) {
83 math.array_update_var_inv_std_grad(
84 &mut self.variance,
85 &mut self.inv_stds,
86 grad,
87 fill_invalid,
88 clamp,
89 );
90 let mut vals = vec![0f64; math.dim()];
91 math.write_to_slice(&self.inv_stds, &mut vals);
92 vals.iter_mut().for_each(|x| *x = x.recip());
93 math.read_from_slice(&mut self.stds, &vals);
94 }
95
96 fn update(&mut self, math: &mut M, mut stds: Col<f64>, vals: Col<f64>, vecs: Mat<f64>) {
97 math.read_from_slice(&mut self.stds, stds.try_as_col_major().unwrap().as_slice());
98
99 stds.iter_mut().for_each(|x| *x = x.recip());
100 math.read_from_slice(
101 &mut self.inv_stds,
102 stds.try_as_col_major().unwrap().as_slice(),
103 );
104
105 stds.iter_mut().for_each(|x| *x = x.recip() * x.recip());
106 math.read_from_slice(
107 &mut self.variance,
108 stds.try_as_col_major().unwrap().as_slice(),
109 );
110
111 if col_all_finite(&vals.as_ref()) & mat_all_finite(&vecs.as_ref()) {
112 self.inner = Some(InnerMatrix::new(math, vals, vecs));
113 } else {
114 self.inner = None;
115 }
116 }
117}
118
/// Settings of the low-rank mass matrix adaptation.
#[derive(Clone, Debug, Copy, Serialize)]
pub struct LowRankSettings {
    /// When `true`, `extract_stats` reports the standard deviations and eigenvalues
    /// of the current mass matrix.
    pub store_mass_matrix: bool,
    /// Regularisation strength. The Gram matrices of the (rescaled, projected) draws
    /// and gradients are divided by `gamma` before the identity is added. A larger
    /// value therefore pulls the estimate more strongly toward the identity.
    pub gamma: f64,
    /// Only eigenvalues greater than `eigval_cutoff` or smaller than
    /// `1 / eigval_cutoff` are kept in the low-rank correction.
    pub eigval_cutoff: f64,
}
125
126impl Default for LowRankSettings {
127 fn default() -> Self {
128 Self {
129 store_mass_matrix: false,
130 gamma: 1e-5,
131 eigval_cutoff: 2f64,
132 }
133 }
134}
135
/// Per-draw statistics describing the current low-rank mass matrix.
#[derive(Debug, Storable)]
pub struct MatrixStats {
    /// Eigenvalues of the low-rank correction, padded with NaN to one entry per
    /// parameter. `None` unless `store_mass_matrix` is set and a correction exists.
    #[storable(dims("unconstrained_parameter"))]
    pub mass_matrix_eigvals: Option<Vec<f64>>,
    /// Diagonal standard deviations. `None` unless `store_mass_matrix` is set.
    #[storable(dims("unconstrained_parameter"))]
    pub mass_matrix_stds: Option<Vec<f64>>,
    /// Number of eigenvalues in the low-rank correction (0 if there is none).
    pub num_eigenvalues: u64,
}
144
145impl<M: Math> SamplerStats<M> for LowRankMassMatrix<M> {
146 type Stats = MatrixStats;
147 type StatsOptions = ();
148
149 fn extract_stats(&self, math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
150 if self.settings.store_mass_matrix {
151 let stds = Some(math.box_array(&self.stds));
152 let eigvals = self
153 .inner
154 .as_ref()
155 .map(|inner| math.eigs_as_array(&inner.vals));
156 let mut eigvals = eigvals.map(|x| x.into_vec());
157 if let Some(ref mut eigvals) = eigvals {
158 eigvals.extend(repeat(f64::NAN).take(stds.as_ref().unwrap().len() - eigvals.len()));
159 }
160 MatrixStats {
161 mass_matrix_eigvals: eigvals,
162 mass_matrix_stds: stds.map(|x| x.into_vec()),
163 num_eigenvalues: self
164 .inner
165 .as_ref()
166 .map(|inner| inner.num_eigenvalues)
167 .unwrap_or(0),
168 }
169 } else {
170 MatrixStats {
171 mass_matrix_eigvals: None,
172 mass_matrix_stds: None,
173 num_eigenvalues: self
174 .inner
175 .as_ref()
176 .map(|inner| inner.num_eigenvalues)
177 .unwrap_or(0),
178 }
179 }
180 }
181}
182
183impl<M: Math> MassMatrix<M> for LowRankMassMatrix<M> {
184 fn update_velocity(&self, math: &mut M, state: &mut EuclideanPoint<M>) {
185 let Some(inner) = self.inner.as_ref() else {
186 math.array_mult(&self.variance, &state.momentum, &mut state.velocity);
187 return;
188 };
189
190 math.array_mult_eigs(
191 &self.stds,
192 &state.momentum,
193 &mut state.velocity,
194 &inner.vecs,
195 &inner.vals,
196 );
197 }
198
199 fn update_kinetic_energy(&self, math: &mut M, state: &mut EuclideanPoint<M>) {
200 state.kinetic_energy = 0.5 * math.array_vector_dot(&state.momentum, &state.velocity);
201 }
202
203 fn randomize_momentum<R: rand::Rng + ?Sized>(
204 &self,
205 math: &mut M,
206 state: &mut EuclideanPoint<M>,
207 rng: &mut R,
208 ) {
209 let Some(inner) = self.inner.as_ref() else {
210 math.array_gaussian(rng, &mut state.momentum, &self.inv_stds);
211 return;
212 };
213
214 math.array_gaussian_eigs(
215 rng,
216 &mut state.momentum,
217 &self.inv_stds,
218 &inner.vals_sqrt_inv,
219 &inner.vecs,
220 );
221 }
222}
223
/// Collects draws and gradients during tuning and turns them into a
/// `LowRankMassMatrix` estimate.
///
/// The first `background_split` buffered entries form the older window, which
/// `switch` discards. The remaining entries form the background window counted by
/// `background_count`.
#[derive(Debug)]
pub struct LowRankMassMatrixStrategy {
    /// Buffered positions, one `Vec` of length `ndim` per draw.
    draws: VecDeque<Vec<f64>>,
    /// Buffered gradients, parallel to `draws`.
    grads: VecDeque<Vec<f64>>,
    /// Number of unconstrained parameters.
    ndim: usize,
    /// Number of leading entries of `draws`/`grads` that the next `switch` drops.
    background_split: usize,
    /// Estimator settings (`gamma`, `eigval_cutoff`, ...).
    settings: LowRankSettings,
}
244
245impl LowRankMassMatrixStrategy {
246 pub fn new(ndim: usize, settings: LowRankSettings) -> Self {
247 let draws = VecDeque::with_capacity(100);
248 let grads = VecDeque::with_capacity(100);
249
250 Self {
251 draws,
252 grads,
253 ndim,
254 background_split: 0,
255 settings,
256 }
257 }
258
259 pub fn add_draw<M: Math>(&mut self, math: &mut M, point: &impl Point<M>) {
260 assert!(math.dim() == self.ndim);
261 let mut draw = vec![0f64; self.ndim];
262 math.write_to_slice(point.position(), &mut draw);
263 let mut grad = vec![0f64; self.ndim];
264 math.write_to_slice(point.gradient(), &mut grad);
265
266 self.draws.push_back(draw);
267 self.grads.push_back(grad);
268 }
269
270 pub fn clear(&mut self) {
271 self.draws.clear();
272 self.grads.clear();
273 }
274
275 pub fn update<M: Math>(&self, math: &mut M, matrix: &mut LowRankMassMatrix<M>) {
276 let draws_vec = &self.draws;
277 let grads_vec = &self.grads;
278
279 let ndraws = draws_vec.len();
280 assert!(grads_vec.len() == ndraws);
281
282 let mut draws: Mat<f64> = Mat::zeros(self.ndim, ndraws);
283 let mut grads: Mat<f64> = Mat::zeros(self.ndim, ndraws);
284
285 for (i, (draw, grad)) in draws_vec.iter().zip(grads_vec.iter()).enumerate() {
286 draws.col_as_slice_mut(i).copy_from_slice(&draw[..]);
287 grads.col_as_slice_mut(i).copy_from_slice(&grad[..]);
288 }
289
290 let Some((stds, vals, vecs)) = self.compute_update(draws, grads) else {
291 return;
292 };
293
294 matrix.update(math, stds, vals, vecs);
295 }
296
297 fn compute_update(
298 &self,
299 mut draws: Mat<f64>,
300 mut grads: Mat<f64>,
301 ) -> Option<(Col<f64>, Col<f64>, Mat<f64>)> {
302 let stds = rescale_points(&mut draws, &mut grads);
303
304 let svd_draws = draws.thin_svd().ok()?;
305 let svd_grads = grads.thin_svd().ok()?;
306
307 let subspace = faer::concat![[svd_draws.U(), svd_grads.U()]];
308
309 let subspace_qr = subspace.col_piv_qr();
310
311 let subspace_basis = subspace_qr.compute_thin_Q();
312
313 let draws_proj = subspace_basis.transpose() * (&draws);
314 let grads_proj = subspace_basis.transpose() * (&grads);
315
316 let (vals, vecs) = estimate_mass_matrix(draws_proj, grads_proj, self.settings.gamma)?;
317
318 let filtered = vals
319 .iter()
320 .zip(vecs.col_iter())
321 .filter(|&(&val, _)| {
322 (val > self.settings.eigval_cutoff) | (val < self.settings.eigval_cutoff.recip())
323 })
324 .collect_vec();
325
326 let vals = filtered.iter().map(|x| *x.0).collect_vec();
327 let vals = ColRef::from_slice(&vals).to_owned();
328
329 let vecs_vec = filtered.into_iter().map(|x| x.1).collect_vec();
330 let mut vecs = Mat::zeros(subspace_basis.ncols(), vals.nrows());
331 vecs.col_iter_mut()
332 .zip(vecs_vec.iter())
333 .for_each(|(mut col, vals)| col.copy_from(vals));
334
335 let vecs = subspace_basis * vecs;
336 Some((stds, vals, vecs))
337 }
338}
339
340fn rescale_points(draws: &mut Mat<f64>, grads: &mut Mat<f64>) -> Col<f64> {
341 let (ndim, ndraws) = draws.shape();
342
343 Col::from_fn(ndim, |col| {
344 let draw_mean = draws.row(col).sum() / (ndraws as f64);
345 let grad_mean = grads.row(col).sum() / (ndraws as f64);
346 let draw_std: f64 = draws
347 .row(col)
348 .iter()
349 .map(|&val| (val - draw_mean) * (val - draw_mean))
350 .sum::<f64>()
351 .sqrt();
352 let grad_std: f64 = grads
353 .row(col)
354 .iter()
355 .map(|&val| (val - grad_mean) * (val - grad_mean))
356 .sum::<f64>()
357 .sqrt();
358
359 let std = (draw_std / grad_std).sqrt();
360
361 let draw_scale = (std * (ndraws as f64)).recip();
362 draws
363 .row_mut(col)
364 .iter_mut()
365 .for_each(|val| *val = (*val - draw_mean) * draw_scale);
366
367 let grad_scale = std * (ndraws as f64).recip();
368 grads
369 .row_mut(col)
370 .iter_mut()
371 .for_each(|val| *val = (*val - grad_mean) * grad_scale);
372
373 std
374 })
375}
376
377fn estimate_mass_matrix(
378 draws: Mat<f64>,
379 grads: Mat<f64>,
380 gamma: f64,
381) -> Option<(Col<f64>, Mat<f64>)> {
382 let mut cov_draws = (&draws) * draws.transpose();
383 let mut cov_grads = (&grads) * grads.transpose();
384
385 cov_draws *= Scale(gamma.recip());
386 cov_grads *= Scale(gamma.recip());
387
388 cov_draws
389 .diagonal_mut()
390 .column_vector_mut()
391 .iter_mut()
392 .for_each(|x| *x += 1f64);
393
394 cov_grads
395 .diagonal_mut()
396 .column_vector_mut()
397 .iter_mut()
398 .for_each(|x| *x += 1f64);
399
400 let mean = spd_mean(cov_draws, cov_grads)?;
401
402 let mean_eig = mean.self_adjoint_eigen(faer::Side::Lower).ok()?;
403
404 Some((
405 mean_eig.S().column_vector().to_owned(),
406 mean_eig.U().to_owned(),
407 ))
408}
409
410fn spd_mean(cov_draws: Mat<f64>, cov_grads: Mat<f64>) -> Option<Mat<f64>> {
411 let eigs_grads = cov_grads.self_adjoint_eigen(faer::Side::Lower).ok()?;
412
413 let u = eigs_grads.U();
414 let eigs = eigs_grads.S().column_vector().to_owned();
415
416 let mut eigs_sqrt = eigs.clone();
417 eigs_sqrt.iter_mut().for_each(|val| *val = val.sqrt());
418 let cov_grads_sqrt = u * eigs_sqrt.into_diagonal() * u.transpose();
419 let m = (&cov_grads_sqrt) * cov_draws * cov_grads_sqrt;
420
421 let m_eig = m.self_adjoint_eigen(faer::Side::Lower).ok()?;
422
423 let m_u = m_eig.U();
424 let mut m_s = m_eig.S().column_vector().to_owned();
425 m_s.iter_mut().for_each(|val| *val = val.sqrt());
426
427 let m_sqrt = m_u * m_s.into_diagonal() * m_u.transpose();
428
429 let mut eigs_grads_inv = eigs;
430 eigs_grads_inv
431 .iter_mut()
432 .for_each(|val| *val = val.sqrt().recip());
433 let grads_inv_sqrt = u * eigs_grads_inv.into_diagonal() * u.transpose();
434
435 Some((&grads_inv_sqrt) * m_sqrt * grads_inv_sqrt)
436}
437
438impl<M: Math> SamplerStats<M> for LowRankMassMatrixStrategy {
439 type Stats = ();
440 type StatsOptions = ();
441
442 fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {}
443}
444
445impl<M: Math> MassMatrixAdaptStrategy<M> for LowRankMassMatrixStrategy {
446 type MassMatrix = LowRankMassMatrix<M>;
447 type Collector = DrawGradCollector<M>;
448 type Options = LowRankSettings;
449
450 fn new(math: &mut M, options: Self::Options, _num_tune: u64, _chain: u64) -> Self {
451 Self::new(math.dim(), options)
452 }
453
454 fn init<R: rand::Rng + ?Sized>(
455 &mut self,
456 math: &mut M,
457 _options: &mut crate::nuts::NutsOptions,
458 mass_matrix: &mut Self::MassMatrix,
459 point: &impl Point<M>,
460 _rng: &mut R,
461 ) -> Result<(), NutsError> {
462 self.add_draw(math, point);
463 mass_matrix.update_from_grad(math, point.gradient(), 1f64, (1e-20, 1e20));
464 Ok(())
465 }
466
467 fn new_collector(&self, math: &mut M) -> Self::Collector {
468 DrawGradCollector::new(math)
469 }
470
471 fn update_estimators(&mut self, math: &mut M, collector: &Self::Collector) {
472 if collector.is_good {
473 let mut draw = vec![0f64; self.ndim];
474 math.write_to_slice(&collector.draw, &mut draw);
475 self.draws.push_back(draw);
476
477 let mut grad = vec![0f64; self.ndim];
478 math.write_to_slice(&collector.grad, &mut grad);
479 self.grads.push_back(grad);
480 }
481 }
482
483 fn switch(&mut self, _math: &mut M) {
484 for _ in 0..self.background_split {
485 self.draws.pop_front().expect("Could not drop draw");
486 self.grads.pop_front().expect("Could not drop gradient");
487 }
488 self.background_split = self.draws.len();
489 assert!(self.draws.len() == self.grads.len());
490 }
491
492 fn current_count(&self) -> u64 {
493 self.draws.len() as u64
494 }
495
496 fn background_count(&self) -> u64 {
497 self.draws.len().checked_sub(self.background_split).unwrap() as u64
498 }
499
500 fn adapt(&self, math: &mut M, mass_matrix: &mut Self::MassMatrix) -> bool {
501 if <LowRankMassMatrixStrategy as MassMatrixAdaptStrategy<M>>::current_count(self) < 3 {
502 return false;
503 }
504 self.update(math, mass_matrix);
505
506 true
507 }
508}
509
#[cfg(test)]
mod test {
    use std::ops::AddAssign;

    // `Cmp` provides the `test` method used on `ApproxEq` below.
    use equator::Cmp;
    use faer::{Col, Mat, utils::approx::ApproxEq};
    use rand::{Rng, SeedableRng, rngs::SmallRng};
    use rand_distr::StandardNormal;

    use super::{estimate_mass_matrix, mat_all_finite, spd_mean};

    /// For diagonal inputs `spd_mean(x, y)` must equal `diag(sqrt(x_i / y_i))`:
    /// sqrt(1/1) = 1, sqrt(4/1) = 2, sqrt(8/0.5) = 4.
    #[test]
    fn test_spd_mean() {
        let x_diag = faer::col![1., 4., 8.];
        let y_diag = faer::col![1., 1., 0.5];

        let mut x = faer::Mat::zeros(3, 3);
        let mut y = faer::Mat::zeros(3, 3);

        x.diagonal_mut().column_vector_mut().add_assign(x_diag);
        y.diagonal_mut().column_vector_mut().add_assign(y_diag);

        let out = spd_mean(x, y).expect("Failed to compute spd mean");
        let expected_diag = faer::col![1., 2., 4.];
        let mut expected = faer::Mat::zeros(3, 3);
        expected
            .diagonal_mut()
            .column_vector_mut()
            .add_assign(expected_diag);

        let comp = ApproxEq {
            abs_tol: 1e-10,
            rel_tol: 1e-10,
        };

        faer::zip!(&out, &expected).for_each(|faer::unzip!(out, expected)| {
            comp.test(out, expected).unwrap();
        });
    }

    /// With `grads = -draws` both regularised Gram matrices coincide. The resulting
    /// `X` with `X · G · X = G` is the identity, so all 20 eigenvalues must be 1.
    #[test]
    fn test_estimate_mass_matrix() {
        let distr = StandardNormal;

        let mut rng = SmallRng::seed_from_u64(1);

        let draws: Mat<f64> = Mat::from_fn(20, 3, |_, _| rng.sample(distr));
        let grads = -(&draws);

        let (vals, vecs) =
            estimate_mass_matrix(draws, grads, 0.0001).expect("Failed to compute mass matrix");
        assert!(vals.iter().cloned().all(|x| x > 0.));
        assert!(mat_all_finite(&vecs.as_ref()));

        let comp = ApproxEq {
            abs_tol: 1e-5,
            rel_tol: 1e-5,
        };

        let expected = Col::full(20, 1.);

        faer::zip!(&vals, &expected).for_each(|faer::unzip!(out, expected)| {
            comp.test(out, expected).unwrap();
        });
    }
}