// nuts_rs/mass_matrix/adapt.rs

1use std::marker::PhantomData;
2
3use nuts_derive::Storable;
4use rand::Rng;
5use serde::Serialize;
6
7use super::diagonal::{DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance};
8use crate::{
9    Math, NutsError,
10    euclidean_hamiltonian::EuclideanPoint,
11    hamiltonian::Point,
12    nuts::{Collector, NutsOptions},
13    sampler_stats::SamplerStats,
14};
/// Lower bound of the `(lower, upper)` limit pair passed to the mass-matrix
/// update calls in `adapt`.
const LOWER_LIMIT: f64 = 1.0e-20;
/// Upper bound of the `(lower, upper)` limit pair passed to the mass-matrix
/// update calls in `adapt`.
const UPPER_LIMIT: f64 = 1.0e20;

/// Lower bound of the limit pair used when `init` builds the first mass matrix
/// from the starting gradient.
const INIT_LOWER_LIMIT: f64 = 1.0e-20;
/// Upper bound of the limit pair used when `init` builds the first mass matrix
/// from the starting gradient.
const INIT_UPPER_LIMIT: f64 = 1.0e20;
20
21/// Settings for mass matrix adaptation
22#[derive(Clone, Copy, Debug, Serialize)]
23pub struct DiagAdaptExpSettings {
24    pub store_mass_matrix: bool,
25    pub use_grad_based_estimate: bool,
26}
27
28impl Default for DiagAdaptExpSettings {
29    fn default() -> Self {
30        Self {
31            store_mass_matrix: false,
32            use_grad_based_estimate: true,
33        }
34    }
35}
36
37pub struct Strategy<M: Math> {
38    exp_variance_draw: RunningVariance<M>,
39    exp_variance_grad: RunningVariance<M>,
40    exp_variance_grad_bg: RunningVariance<M>,
41    exp_variance_draw_bg: RunningVariance<M>,
42    _settings: DiagAdaptExpSettings,
43    _phantom: PhantomData<M>,
44}
45
/// Statistics reported by [`Strategy`]. It has no fields, so this strategy
/// contributes no per-draw values.
#[derive(Debug, Storable)]
pub struct Stats {}
48
49impl<M: Math> SamplerStats<M> for Strategy<M> {
50    type Stats = Stats;
51    type StatsOptions = ();
52
53    fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
54        Stats {}
55    }
56}
57
58pub trait MassMatrixAdaptStrategy<M: Math>: SamplerStats<M> {
59    type MassMatrix: MassMatrix<M>;
60    type Collector: Collector<M, EuclideanPoint<M>>;
61    type Options: std::fmt::Debug + Default + Clone + Send + Sync + Copy;
62
63    fn update_estimators(&mut self, math: &mut M, collector: &Self::Collector);
64
65    fn switch(&mut self, math: &mut M);
66
67    fn current_count(&self) -> u64;
68
69    fn background_count(&self) -> u64;
70
71    /// Give the opportunity to update the potential and return if it was changed
72    fn adapt(&self, math: &mut M, mass_matrix: &mut Self::MassMatrix) -> bool;
73
74    fn new(math: &mut M, options: Self::Options, _num_tune: u64, _chain: u64) -> Self;
75
76    fn init<R: Rng + ?Sized>(
77        &mut self,
78        math: &mut M,
79        _options: &mut NutsOptions,
80        mass_matrix: &mut Self::MassMatrix,
81        point: &impl Point<M>,
82        _rng: &mut R,
83    ) -> Result<(), NutsError>;
84
85    fn new_collector(&self, math: &mut M) -> Self::Collector;
86}
87
88impl<M: Math> MassMatrixAdaptStrategy<M> for Strategy<M> {
89    type MassMatrix = DiagMassMatrix<M>;
90    type Collector = DrawGradCollector<M>;
91    type Options = DiagAdaptExpSettings;
92
93    fn update_estimators(&mut self, math: &mut M, collector: &DrawGradCollector<M>) {
94        if collector.is_good {
95            self.exp_variance_draw.add_sample(math, &collector.draw);
96            self.exp_variance_grad.add_sample(math, &collector.grad);
97            self.exp_variance_draw_bg.add_sample(math, &collector.draw);
98            self.exp_variance_grad_bg.add_sample(math, &collector.grad);
99        }
100    }
101
102    fn switch(&mut self, math: &mut M) {
103        self.exp_variance_draw =
104            std::mem::replace(&mut self.exp_variance_draw_bg, RunningVariance::new(math));
105        self.exp_variance_grad =
106            std::mem::replace(&mut self.exp_variance_grad_bg, RunningVariance::new(math));
107    }
108
109    fn current_count(&self) -> u64 {
110        assert!(self.exp_variance_draw.count() == self.exp_variance_grad.count());
111        self.exp_variance_draw.count()
112    }
113
114    fn background_count(&self) -> u64 {
115        assert!(self.exp_variance_draw_bg.count() == self.exp_variance_grad_bg.count());
116        self.exp_variance_draw_bg.count()
117    }
118
119    /// Give the opportunity to update the potential and return if it was changed
120    fn adapt(&self, math: &mut M, mass_matrix: &mut DiagMassMatrix<M>) -> bool {
121        if self.current_count() < 3 {
122            return false;
123        }
124
125        let (draw_var, draw_scale) = self.exp_variance_draw.current();
126        let (grad_var, grad_scale) = self.exp_variance_grad.current();
127        assert!(draw_scale == grad_scale);
128
129        if self._settings.use_grad_based_estimate {
130            mass_matrix.update_diag_draw_grad(
131                math,
132                draw_var,
133                grad_var,
134                None,
135                (LOWER_LIMIT, UPPER_LIMIT),
136            );
137        } else {
138            let scale = (self.exp_variance_draw.count() as f64).recip();
139            mass_matrix.update_diag_draw(math, draw_var, scale, None, (LOWER_LIMIT, UPPER_LIMIT));
140        }
141
142        true
143    }
144
145    fn new(math: &mut M, options: Self::Options, _num_tune: u64, _chain: u64) -> Self {
146        Self {
147            exp_variance_draw: RunningVariance::new(math),
148            exp_variance_grad: RunningVariance::new(math),
149            exp_variance_draw_bg: RunningVariance::new(math),
150            exp_variance_grad_bg: RunningVariance::new(math),
151            _settings: options,
152            _phantom: PhantomData,
153        }
154    }
155
156    fn init<R: Rng + ?Sized>(
157        &mut self,
158        math: &mut M,
159        _options: &mut NutsOptions,
160        mass_matrix: &mut Self::MassMatrix,
161        point: &impl Point<M>,
162        _rng: &mut R,
163    ) -> Result<(), NutsError> {
164        self.exp_variance_draw.add_sample(math, point.position());
165        self.exp_variance_draw_bg.add_sample(math, point.position());
166        self.exp_variance_grad.add_sample(math, point.gradient());
167        self.exp_variance_grad_bg.add_sample(math, point.gradient());
168
169        mass_matrix.update_diag_grad(
170            math,
171            point.gradient(),
172            1f64,
173            (INIT_LOWER_LIMIT, INIT_UPPER_LIMIT),
174        );
175        Ok(())
176    }
177
178    fn new_collector(&self, math: &mut M) -> Self::Collector {
179        DrawGradCollector::new(math)
180    }
181}