nuts_rs/transform/adapt/diagonal.rs
1//! Online estimator that adapts the diagonal mass matrix from draw and gradient variance during warmup.
2
3use std::marker::PhantomData;
4
5use nuts_derive::Storable;
6use rand::Rng;
7use serde::{Deserialize, Serialize};
8
9use crate::{
10    Math, NutsError, SamplerStats,
11    dynamics::{Point, State},
12    nuts::{Collector, NutsOptions},
13    transform::{DiagMassMatrix, adapt::strategy::MassMatrixAdaptStrategy},
14};
15
/// Online estimator of the element-wise mean and variance of a stream of
/// vectors (presumably Welford's algorithm — the actual update is delegated
/// to `Math::array_update_variance`; confirm against that implementation).
#[derive(Debug)]
pub struct RunningVariance<M: Math> {
    /// Running mean of all samples seen so far.
    mean: M::Vector,
    /// Unscaled accumulator of squared deviations; multiply by the scale
    /// factor `1/(count-1)` returned from `current` to get the sample variance.
    variance: M::Vector,
    /// Number of samples accumulated.
    count: u64,
}
22
23impl<M: Math> RunningVariance<M> {
24    pub(crate) fn new(math: &mut M) -> Self {
25        Self {
26            mean: math.new_array(),
27            variance: math.new_array(),
28            count: 0,
29        }
30    }
31
32    pub(crate) fn add_sample(&mut self, math: &mut M, value: &M::Vector) {
33        self.count += 1;
34        if self.count == 1 {
35            math.copy_into(value, &mut self.mean);
36        } else {
37            math.array_update_variance(
38                &mut self.mean,
39                &mut self.variance,
40                value,
41                (self.count as f64).recip(),
42            );
43        }
44    }
45
46    /// Return current variance and scaling factor
47    pub(crate) fn current(&self) -> (&M::Vector, f64) {
48        assert!(self.count > 1);
49        (&self.variance, ((self.count - 1) as f64).recip())
50    }
51
52    pub(crate) fn count(&self) -> u64 {
53        self.count
54    }
55}
56
/// Collector that snapshots the position and gradient of each accepted draw,
/// together with a flag saying whether that draw should feed the mass matrix
/// estimators.
pub struct DrawGradCollector<M: Math> {
    /// Position of the most recently registered draw.
    pub(crate) draw: M::Vector,
    /// Gradient at the most recently registered draw.
    pub(crate) grad: M::Vector,
    /// Whether the last draw is usable for adaptation (set in `register_draw`).
    pub(crate) is_good: bool,
}
62
63impl<M: Math> DrawGradCollector<M> {
64    pub(crate) fn new(math: &mut M) -> Self {
65        DrawGradCollector {
66            draw: math.new_array(),
67            grad: math.new_array(),
68            is_good: true,
69        }
70    }
71}
72
73impl<M: Math, P: Point<M>> Collector<M, P> for DrawGradCollector<M> {
74    fn register_draw(&mut self, math: &mut M, state: &State<M, P>, info: &crate::nuts::SampleInfo) {
75        math.copy_into(state.point().position(), &mut self.draw);
76        math.copy_into(state.point().gradient(), &mut self.grad);
77        let idx = state.index_in_trajectory();
78        if info.divergence_info.is_some() {
79            self.is_good = idx.abs() > 4;
80        } else {
81            self.is_good = idx != 0;
82        }
83    }
84}
85
// Clamp bounds applied to the diagonal mass matrix entries during regular
// adaptation, guarding against degenerate (zero or unbounded) variances.
const LOWER_LIMIT: f64 = 1e-20f64;
const UPPER_LIMIT: f64 = 1e20f64;

// Clamp bounds used for the gradient-based initialization in `init`.
// Currently identical to the regular limits, but kept separate so they can
// be tuned independently.
const INIT_LOWER_LIMIT: f64 = 1e-20f64;
const INIT_UPPER_LIMIT: f64 = 1e20f64;
91
/// Settings for mass matrix adaptation
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct DiagAdaptExpSettings {
    // Presumably controls whether the adapted mass matrix is exported with
    // the sampler stats; it is not consulted anywhere in this file — verify
    // against the stats-extraction code.
    pub store_mass_matrix: bool,
    /// If true, combine draw and gradient variances for the estimate;
    /// otherwise use the draw variance alone (see `Strategy::adapt`).
    pub use_grad_based_estimate: bool,
}
98
99impl Default for DiagAdaptExpSettings {
100    fn default() -> Self {
101        Self {
102            store_mass_matrix: false,
103            use_grad_based_estimate: true,
104        }
105    }
106}
107
/// Diagonal mass matrix adaptation strategy that maintains two pairs of
/// running variance estimators over draws and gradients: a "foreground" pair
/// used for the current estimate and a "background" pair that replaces the
/// foreground on `switch`, discarding early warmup samples.
pub struct Strategy<M: Math> {
    // Foreground estimators (feed the current mass matrix estimate).
    exp_variance_draw: RunningVariance<M>,
    exp_variance_grad: RunningVariance<M>,
    // Background estimators, promoted to foreground by `switch`.
    exp_variance_grad_bg: RunningVariance<M>,
    exp_variance_draw_bg: RunningVariance<M>,
    _settings: DiagAdaptExpSettings,
    // NOTE(review): PhantomData looks redundant — `M` already appears in the
    // estimator fields above. Presumably a leftover; removing it would also
    // require touching `Strategy::new`.
    _phantom: PhantomData<M>,
}
116
/// Sampler statistics for this strategy. Currently empty: the strategy
/// exports no per-draw statistics.
#[derive(Debug, Storable)]
pub struct Stats {}

impl<M: Math> SamplerStats<M> for Strategy<M> {
    type Stats = Stats;
    type StatsOptions = ();

    /// Nothing to extract — returns the empty `Stats` marker.
    fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
        Stats {}
    }
}
128
impl<M: Math> MassMatrixAdaptStrategy<M> for Strategy<M> {
    type Transformation = DiagMassMatrix<M>;
    type Collector = DrawGradCollector<M>;
    type Options = DiagAdaptExpSettings;

    /// Feed the collector's latest draw/gradient into all four estimators,
    /// but only if the collector flagged the draw as usable.
    fn update_estimators(&mut self, math: &mut M, collector: &DrawGradCollector<M>) {
        if collector.is_good {
            self.exp_variance_draw.add_sample(math, &collector.draw);
            self.exp_variance_grad.add_sample(math, &collector.grad);
            self.exp_variance_draw_bg.add_sample(math, &collector.draw);
            self.exp_variance_grad_bg.add_sample(math, &collector.grad);
        }
    }

    /// Promote the background estimators to foreground and start fresh
    /// background estimators. This discards older warmup samples so the
    /// estimate tracks the more recent (better-adapted) part of warmup.
    fn switch(&mut self, math: &mut M) {
        self.exp_variance_draw =
            std::mem::replace(&mut self.exp_variance_draw_bg, RunningVariance::new(math));
        self.exp_variance_grad =
            std::mem::replace(&mut self.exp_variance_grad_bg, RunningVariance::new(math));
    }

    /// Number of samples in the foreground estimators. Draw and gradient
    /// counts are always updated together, hence the invariant assert.
    fn current_count(&self) -> u64 {
        assert!(self.exp_variance_draw.count() == self.exp_variance_grad.count());
        self.exp_variance_draw.count()
    }

    /// Number of samples in the background estimators.
    fn background_count(&self) -> u64 {
        assert!(self.exp_variance_draw_bg.count() == self.exp_variance_grad_bg.count());
        self.exp_variance_draw_bg.count()
    }

    /// Give the opportunity to update the potential and return if it was changed
    ///
    /// Requires at least 3 foreground samples; otherwise the mass matrix is
    /// left untouched and `false` is returned.
    fn adapt(&self, math: &mut M, mass_matrix: &mut DiagMassMatrix<M>) -> bool {
        if self.current_count() < 3 {
            return false;
        }

        // Both scales are 1/(count-1) from the same sample count, so they
        // must agree exactly (same bit pattern, no float-tolerance issue).
        let (draw_var, draw_scale) = self.exp_variance_draw.current();
        let (grad_var, grad_scale) = self.exp_variance_grad.current();
        assert!(draw_scale == grad_scale);

        let draw_mean = &self.exp_variance_draw.mean;
        let grad_mean = &self.exp_variance_grad.mean;

        if self._settings.use_grad_based_estimate {
            // Combined draw/gradient variance estimate, clamped to the
            // regular limits.
            mass_matrix.update_diag_draw_grad(
                math,
                draw_mean,
                grad_mean,
                draw_var,
                grad_var,
                None,
                (LOWER_LIMIT, UPPER_LIMIT),
            );
        } else {
            // NOTE(review): this uses 1/count, not the 1/(count-1) factor
            // returned by `current()` above — i.e. a biased variance
            // estimate. Confirm this asymmetry with the grad-based branch
            // is intentional.
            let scale = (self.exp_variance_draw.count() as f64).recip();
            mass_matrix.update_diag_draw(
                math,
                draw_mean,
                draw_var,
                scale,
                None,
                (LOWER_LIMIT, UPPER_LIMIT),
            );
        }

        true
    }

    /// Construct a strategy with empty estimators. `_num_tune` and `_chain`
    /// are unused by this strategy.
    fn new(math: &mut M, options: Self::Options, _num_tune: u64, _chain: u64) -> Self {
        Self {
            exp_variance_draw: RunningVariance::new(math),
            exp_variance_grad: RunningVariance::new(math),
            exp_variance_draw_bg: RunningVariance::new(math),
            exp_variance_grad_bg: RunningVariance::new(math),
            _settings: options,
            _phantom: PhantomData,
        }
    }

    /// Seed all estimators with the initial point and set an initial mass
    /// matrix from its gradient (clamped to the init limits).
    fn init<R: Rng + ?Sized>(
        &mut self,
        math: &mut M,
        _options: &mut NutsOptions,
        mass_matrix: &mut Self::Transformation,
        point: &impl Point<M>,
        _rng: &mut R,
    ) -> Result<(), NutsError> {
        self.exp_variance_draw.add_sample(math, point.position());
        self.exp_variance_draw_bg.add_sample(math, point.position());
        self.exp_variance_grad.add_sample(math, point.gradient());
        self.exp_variance_grad_bg.add_sample(math, point.gradient());

        mass_matrix.update_diag_grad(
            math,
            point.position(),
            point.gradient(),
            1f64,
            (INIT_LOWER_LIMIT, INIT_UPPER_LIMIT),
        );

        Ok(())
    }

    /// Fresh collector for the next adaptation window.
    fn new_collector(&self, math: &mut M) -> Self::Collector {
        DrawGradCollector::new(math)
    }
}
232
233    fn new_collector(&self, math: &mut M) -> Self::Collector {
234        DrawGradCollector::new(math)
235    }
236}