1use nuts_derive::Storable;
2use nuts_storable::{HasDims, Storable};
3use serde::Serialize;
4
5use crate::adapt_strategy::CombinedCollector;
6use crate::chain::AdaptStrategy;
7use crate::hamiltonian::{Hamiltonian, Point};
8use crate::nuts::{Collector, NutsOptions, SampleInfo};
9use crate::sampler_stats::{SamplerStats, StatsDims};
10use crate::state::State;
11use crate::stepsize::AcceptanceRateCollector;
12use crate::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy};
13use crate::transformed_hamiltonian::TransformedHamiltonian;
14use crate::{Math, NutsError};
15
16#[derive(Clone, Copy, Debug, Serialize)]
17pub struct TransformedSettings {
18 pub step_size_window: f64,
19 pub transform_update_freq: u64,
20 pub use_orbit_for_training: bool,
21 pub step_size_settings: StepSizeSettings,
22 pub transform_train_max_energy_error: f64,
23}
24
25impl Default for TransformedSettings {
26 fn default() -> Self {
27 Self {
28 step_size_window: 0.07f64,
29 transform_update_freq: 128,
30 use_orbit_for_training: false,
31 transform_train_max_energy_error: 20f64,
32 step_size_settings: Default::default(),
33 }
34 }
35}
36
37pub struct TransformAdaptation {
38 step_size: StepSizeStrategy,
39 options: TransformedSettings,
40 num_tune: u64,
41 final_window_size: u64,
42 tuning: bool,
43 chain: u64,
44}
45
46#[derive(Debug, Storable)]
47pub struct Stats<P: HasDims, S: Storable<P>> {
48 tuning: bool,
49 #[storable(flatten)]
50 pub step_size: S,
51 #[storable(ignore)]
52 _phantom: std::marker::PhantomData<fn() -> P>,
53}
54
55impl<M: Math> SamplerStats<M> for TransformAdaptation {
56 type Stats = Stats<StatsDims, <StepSizeStrategy as SamplerStats<M>>::Stats>;
57 type StatsOptions = ();
58
59 fn extract_stats(&self, math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
60 Stats {
61 tuning: self.tuning,
62 step_size: { self.step_size.extract_stats(math, ()) },
63 _phantom: std::marker::PhantomData,
64 }
65 }
66}
67
68pub struct DrawCollector<M: Math> {
69 draws: Vec<M::Vector>,
70 grads: Vec<M::Vector>,
71 logps: Vec<f64>,
72 collect_orbit: bool,
73 max_energy_error: f64,
74}
75
76impl<M: Math> DrawCollector<M> {
77 fn new(_math: &mut M, collect_orbit: bool, max_energy_error: f64) -> Self {
78 Self {
79 draws: vec![],
80 grads: vec![],
81 logps: vec![],
82 collect_orbit,
83 max_energy_error,
84 }
85 }
86}
87
88impl<M: Math, P: Point<M>> Collector<M, P> for DrawCollector<M> {
89 fn register_leapfrog(
90 &mut self,
91 math: &mut M,
92 _start: &State<M, P>,
93 end: &State<M, P>,
94 divergence_info: Option<&crate::DivergenceInfo>,
95 ) {
96 if divergence_info.is_some() {
97 return;
98 }
99
100 if self.collect_orbit {
101 let point = end.point();
102 let energy_error = point.energy_error();
103 if !energy_error.is_finite() {
104 return;
105 }
106
107 if energy_error > self.max_energy_error {
108 return;
109 }
110
111 if !math.array_all_finite(point.position()) {
112 return;
113 }
114 if !math.array_all_finite(point.gradient()) {
115 return;
116 }
117
118 self.draws.push(math.copy_array(point.position()));
119 self.grads.push(math.copy_array(point.gradient()));
120 self.logps.push(point.logp());
121 }
122 }
123
124 fn register_draw(&mut self, math: &mut M, state: &State<M, P>, _info: &SampleInfo) {
125 if !self.collect_orbit {
126 let point = state.point();
127 let energy_error = point.energy_error();
128 if !energy_error.is_finite() {
129 return;
130 }
131
132 if energy_error > self.max_energy_error {
133 return;
134 }
135
136 if !math.array_all_finite(point.position()) {
137 return;
138 }
139 if !math.array_all_finite(point.gradient()) {
140 return;
141 }
142
143 self.draws.push(math.copy_array(point.position()));
144 self.grads.push(math.copy_array(point.gradient()));
145 self.logps.push(point.logp());
146 }
147 }
148}
149
150impl<M: Math> AdaptStrategy<M> for TransformAdaptation {
151 type Hamiltonian = TransformedHamiltonian<M>;
152
153 type Collector = CombinedCollector<
154 M,
155 <Self::Hamiltonian as Hamiltonian<M>>::Point,
156 AcceptanceRateCollector,
157 DrawCollector<M>,
158 >;
159
160 type Options = TransformedSettings;
161
162 fn new(_math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self {
163 let step_size = StepSizeStrategy::new(options.step_size_settings);
164 let final_window_size =
165 ((num_tune as f64) * (1f64 - options.step_size_window)).floor() as u64;
166 Self {
167 step_size,
168 options,
169 num_tune,
170 final_window_size,
171 tuning: true,
172 chain,
173 }
174 }
175
176 fn init<R: rand::Rng + ?Sized>(
177 &mut self,
178 math: &mut M,
179 options: &mut NutsOptions,
180 hamiltonian: &mut Self::Hamiltonian,
181 position: &[f64],
182 rng: &mut R,
183 ) -> Result<(), NutsError> {
184 hamiltonian.init_transformation(rng, math, position, self.chain)?;
185 self.step_size
186 .init(math, options, hamiltonian, position, rng)?;
187 Ok(())
188 }
189
190 fn adapt<R: rand::Rng + ?Sized>(
191 &mut self,
192 math: &mut M,
193 _options: &mut NutsOptions,
194 hamiltonian: &mut Self::Hamiltonian,
195 draw: u64,
196 collector: &Self::Collector,
197 _state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
198 rng: &mut R,
199 ) -> Result<(), NutsError> {
200 self.step_size.update(&collector.collector1);
201
202 if draw >= self.num_tune {
203 self.step_size.update_stepsize(rng, hamiltonian, true);
205 self.tuning = false;
206 return Ok(());
207 }
208
209 if draw < self.final_window_size {
210 if draw < 100 {
211 if (draw > 0) && draw.is_multiple_of(10) {
212 hamiltonian.update_params(
213 math,
214 rng,
215 collector.collector2.draws.iter(),
216 collector.collector2.grads.iter(),
217 collector.collector2.logps.iter(),
218 )?;
219 }
220 } else if (draw > 0) && draw.is_multiple_of(self.options.transform_update_freq) {
221 hamiltonian.update_params(
222 math,
223 rng,
224 collector.collector2.draws.iter(),
225 collector.collector2.grads.iter(),
226 collector.collector2.logps.iter(),
227 )?;
228 }
229 self.step_size.update_estimator_early();
230 self.step_size.update_stepsize(rng, hamiltonian, false);
231 return Ok(());
232 }
233
234 self.step_size.update_estimator_late();
235 let is_last = draw == self.num_tune - 1;
236 self.step_size.update_stepsize(rng, hamiltonian, is_last);
237 Ok(())
238 }
239
240 fn new_collector(&self, math: &mut M) -> Self::Collector {
241 Self::Collector::new(
242 self.step_size.new_collector(),
243 DrawCollector::new(
244 math,
245 self.options.use_orbit_for_training,
246 self.options.transform_train_max_energy_error,
247 ),
248 )
249 }
250
251 fn is_tuning(&self) -> bool {
252 self.tuning
253 }
254
255 fn last_num_steps(&self) -> u64 {
256 self.step_size.last_n_steps
257 }
258}