1use std::marker::PhantomData;
4
5use nuts_derive::Storable;
6use rand::Rng;
7use serde::{Deserialize, Serialize};
8
9use crate::{
10 Math, NutsError, SamplerStats,
11 dynamics::{Point, State},
12 nuts::{Collector, NutsOptions},
13 transform::{DiagMassMatrix, adapt::strategy::MassMatrixAdaptStrategy},
14};
15
/// Online estimator of the per-dimension mean and (unscaled) variance of a
/// stream of vectors. The array arithmetic itself is delegated to the `Math`
/// backend via `array_update_variance`; the update weights used in
/// `add_sample` are consistent with Welford's online algorithm, but the
/// exact recurrence lives in the backend — confirm there if it matters.
#[derive(Debug)]
pub struct RunningVariance<M: Math> {
    // Running mean of all samples seen so far.
    mean: M::Vector,
    // Accumulated squared deviations; multiply by the scale returned from
    // `current()` to obtain the sample variance.
    variance: M::Vector,
    // Number of samples accumulated so far.
    count: u64,
}
22
23impl<M: Math> RunningVariance<M> {
24 pub(crate) fn new(math: &mut M) -> Self {
25 Self {
26 mean: math.new_array(),
27 variance: math.new_array(),
28 count: 0,
29 }
30 }
31
32 pub(crate) fn add_sample(&mut self, math: &mut M, value: &M::Vector) {
33 self.count += 1;
34 if self.count == 1 {
35 math.copy_into(value, &mut self.mean);
36 } else {
37 math.array_update_variance(
38 &mut self.mean,
39 &mut self.variance,
40 value,
41 (self.count as f64).recip(),
42 );
43 }
44 }
45
46 pub(crate) fn current(&self) -> (&M::Vector, f64) {
48 assert!(self.count > 1);
49 (&self.variance, ((self.count - 1) as f64).recip())
50 }
51
52 pub(crate) fn count(&self) -> u64 {
53 self.count
54 }
55}
56
/// Collector that captures the position and gradient of the draw selected
/// from a NUTS trajectory, together with a flag saying whether that draw
/// should be fed into the mass-matrix estimators.
pub struct DrawGradCollector<M: Math> {
    // Position of the registered draw.
    pub(crate) draw: M::Vector,
    // Gradient at the registered draw.
    pub(crate) grad: M::Vector,
    // Set by `register_draw`; gates `update_estimators` in the strategy.
    pub(crate) is_good: bool,
}
62
63impl<M: Math> DrawGradCollector<M> {
64 pub(crate) fn new(math: &mut M) -> Self {
65 DrawGradCollector {
66 draw: math.new_array(),
67 grad: math.new_array(),
68 is_good: true,
69 }
70 }
71}
72
73impl<M: Math, P: Point<M>> Collector<M, P> for DrawGradCollector<M> {
74 fn register_draw(&mut self, math: &mut M, state: &State<M, P>, info: &crate::nuts::SampleInfo) {
75 math.copy_into(state.point().position(), &mut self.draw);
76 math.copy_into(state.point().gradient(), &mut self.grad);
77 let idx = state.index_in_trajectory();
78 if info.divergence_info.is_some() {
79 self.is_good = idx.abs() > 4;
80 } else {
81 self.is_good = idx != 0;
82 }
83 }
84}
85
// Clamping bounds for the diagonal mass matrix entries during regular
// adaptation updates.
const LOWER_LIMIT: f64 = 1e-20f64;
const UPPER_LIMIT: f64 = 1e20f64;

// Clamping bounds used for the initial estimate in `init`. Currently equal
// to the regular limits, but kept separate so they can be tuned
// independently.
const INIT_LOWER_LIMIT: f64 = 1e-20f64;
const INIT_UPPER_LIMIT: f64 = 1e20f64;
91
/// User-facing options for the diagonal mass-matrix adaptation strategy.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct DiagAdaptExpSettings {
    /// Whether the adapted mass matrix should be stored with the sampler
    /// stats. NOTE(review): not read anywhere in this file — presumably
    /// consumed by the mass-matrix/stats machinery; confirm against callers.
    pub store_mass_matrix: bool,
    /// If true, `adapt` combines draw and gradient variances
    /// (`update_diag_draw_grad`); if false it uses the draw variance alone
    /// (`update_diag_draw`).
    pub use_grad_based_estimate: bool,
}
98
99impl Default for DiagAdaptExpSettings {
100 fn default() -> Self {
101 Self {
102 store_mass_matrix: false,
103 use_grad_based_estimate: true,
104 }
105 }
106}
107
/// Diagonal mass-matrix adaptation strategy based on running variance
/// estimates of draws and gradients. Keeps a foreground estimator pair
/// (used by `adapt`) and a background pair that is promoted on `switch`,
/// so estimates can be restarted without losing the current window.
pub struct Strategy<M: Math> {
    // Foreground estimators, consulted by `adapt`.
    exp_variance_draw: RunningVariance<M>,
    exp_variance_grad: RunningVariance<M>,
    // Background estimators, promoted to foreground by `switch`.
    exp_variance_grad_bg: RunningVariance<M>,
    exp_variance_draw_bg: RunningVariance<M>,
    _settings: DiagAdaptExpSettings,
    // NOTE(review): PhantomData is redundant here since `M` already occurs
    // in the fields above; kept as-is.
    _phantom: PhantomData<M>,
}
116
/// Sampler statistics exposed by this strategy — currently empty.
#[derive(Debug, Storable)]
pub struct Stats {}
119
impl<M: Math> SamplerStats<M> for Strategy<M> {
    type Stats = Stats;
    type StatsOptions = ();

    // This strategy currently records no per-draw statistics.
    fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
        Stats {}
    }
}
128
129impl<M: Math> MassMatrixAdaptStrategy<M> for Strategy<M> {
130 type Transformation = DiagMassMatrix<M>;
131 type Collector = DrawGradCollector<M>;
132 type Options = DiagAdaptExpSettings;
133
134 fn update_estimators(&mut self, math: &mut M, collector: &DrawGradCollector<M>) {
135 if collector.is_good {
136 self.exp_variance_draw.add_sample(math, &collector.draw);
137 self.exp_variance_grad.add_sample(math, &collector.grad);
138 self.exp_variance_draw_bg.add_sample(math, &collector.draw);
139 self.exp_variance_grad_bg.add_sample(math, &collector.grad);
140 }
141 }
142
143 fn switch(&mut self, math: &mut M) {
144 self.exp_variance_draw =
145 std::mem::replace(&mut self.exp_variance_draw_bg, RunningVariance::new(math));
146 self.exp_variance_grad =
147 std::mem::replace(&mut self.exp_variance_grad_bg, RunningVariance::new(math));
148 }
149
150 fn current_count(&self) -> u64 {
151 assert!(self.exp_variance_draw.count() == self.exp_variance_grad.count());
152 self.exp_variance_draw.count()
153 }
154
155 fn background_count(&self) -> u64 {
156 assert!(self.exp_variance_draw_bg.count() == self.exp_variance_grad_bg.count());
157 self.exp_variance_draw_bg.count()
158 }
159
160 fn adapt(&self, math: &mut M, mass_matrix: &mut DiagMassMatrix<M>) -> bool {
162 if self.current_count() < 3 {
163 return false;
164 }
165
166 let (draw_var, draw_scale) = self.exp_variance_draw.current();
167 let (grad_var, grad_scale) = self.exp_variance_grad.current();
168 assert!(draw_scale == grad_scale);
169
170 let draw_mean = &self.exp_variance_draw.mean;
171 let grad_mean = &self.exp_variance_grad.mean;
172
173 if self._settings.use_grad_based_estimate {
174 mass_matrix.update_diag_draw_grad(
175 math,
176 draw_mean,
177 grad_mean,
178 draw_var,
179 grad_var,
180 None,
181 (LOWER_LIMIT, UPPER_LIMIT),
182 );
183 } else {
184 let scale = (self.exp_variance_draw.count() as f64).recip();
185 mass_matrix.update_diag_draw(
186 math,
187 draw_mean,
188 draw_var,
189 scale,
190 None,
191 (LOWER_LIMIT, UPPER_LIMIT),
192 );
193 }
194
195 true
196 }
197
198 fn new(math: &mut M, options: Self::Options, _num_tune: u64, _chain: u64) -> Self {
199 Self {
200 exp_variance_draw: RunningVariance::new(math),
201 exp_variance_grad: RunningVariance::new(math),
202 exp_variance_draw_bg: RunningVariance::new(math),
203 exp_variance_grad_bg: RunningVariance::new(math),
204 _settings: options,
205 _phantom: PhantomData,
206 }
207 }
208
209 fn init<R: Rng + ?Sized>(
210 &mut self,
211 math: &mut M,
212 _options: &mut NutsOptions,
213 mass_matrix: &mut Self::Transformation,
214 point: &impl Point<M>,
215 _rng: &mut R,
216 ) -> Result<(), NutsError> {
217 self.exp_variance_draw.add_sample(math, point.position());
218 self.exp_variance_draw_bg.add_sample(math, point.position());
219 self.exp_variance_grad.add_sample(math, point.gradient());
220 self.exp_variance_grad_bg.add_sample(math, point.gradient());
221
222 mass_matrix.update_diag_grad(
223 math,
224 point.position(),
225 point.gradient(),
226 1f64,
227 (INIT_LOWER_LIMIT, INIT_UPPER_LIMIT),
228 );
229
230 Ok(())
231 }
232
233 fn new_collector(&self, math: &mut M) -> Self::Collector {
234 DrawGradCollector::new(math)
235 }
236}