1use nuts_derive::Storable;
4use nuts_storable::{HasDims, Storable};
5use serde::{Deserialize, Serialize};
6
7use crate::adapt_strategy::CombinedCollector;
8use crate::chain::AdaptStrategy;
9use crate::dynamics::{Hamiltonian, Point, State, TransformedHamiltonian, TransformedPoint};
10use crate::nuts::{Collector, NutsOptions, SampleInfo};
11use crate::sampler_stats::{SamplerStats, StatsDims};
12use crate::stepsize::AcceptanceRateCollector;
13use crate::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy};
14use crate::transform::ExternalTransformation;
15use crate::{Math, NutsError};
16
/// Settings for [`ExternalTransformAdaptation`], which adapts the step size while
/// periodically retraining an external transformation during tuning.
// Field order is left untouched: the serde derives (de)serialize fields in
// declaration order.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct FlowSettings {
    /// Fraction of the tuning draws, counted from the end of tuning, that form the
    /// final window. Inside that window the transformation is no longer retrained
    /// and the step size uses the late estimator.
    pub step_size_window: f64,
    /// After the first 100 tuning draws (and before the final window), retrain the
    /// transformation on every draw index that is a multiple of this value.
    /// A value of `0` disables these later retraining steps.
    pub transform_update_freq: u64,
    /// If `true`, every non-divergent leapfrog endpoint of a trajectory is used as
    /// training data; if `false`, only the sampled draw of each trajectory is used.
    pub use_orbit_for_training: bool,
    /// Settings forwarded to the step size adaptation strategy.
    pub step_size_settings: StepSizeSettings,
    /// Points whose energy error is non-finite or larger than this value are not
    /// used to train the transformation.
    pub transform_train_max_energy_error: f64,
}
25
/// Deprecated alias of [`FlowSettings`]; new code should use `FlowSettings` directly.
// NOTE(review): `since = "0.0.0"` looks like a placeholder — confirm the release in
// which this alias was deprecated.
#[deprecated(since = "0.0.0", note = "Use FlowSettings instead")]
pub type TransformedSettings = FlowSettings;
29
30impl Default for FlowSettings {
31 fn default() -> Self {
32 Self {
33 step_size_window: 0.07f64,
34 transform_update_freq: 128,
35 use_orbit_for_training: false,
36 transform_train_max_energy_error: 20f64,
37 step_size_settings: Default::default(),
38 }
39 }
40}
41
42pub struct ExternalTransformAdaptation {
43 step_size: StepSizeStrategy,
44 options: FlowSettings,
45 num_tune: u64,
46 final_window_size: u64,
47 tuning: bool,
48 chain: u64,
49}
50
/// Sampler statistics reported by [`ExternalTransformAdaptation`].
#[derive(Debug, Storable)]
pub struct Stats<P: HasDims, S: Storable<P>> {
    /// Whether the adaptation was still tuning when these stats were extracted.
    tuning: bool,
    /// Statistics of the step size strategy. Marked `flatten` for the `Storable`
    /// derive (presumably its fields are stored alongside `tuning` — confirm).
    #[storable(flatten)]
    pub step_size: S,
    /// Ties the otherwise unused type parameter `P` to the struct; excluded from
    /// storage via `#[storable(ignore)]`.
    #[storable(ignore)]
    _phantom: std::marker::PhantomData<fn() -> P>,
}
59
60impl<M: Math> SamplerStats<M> for ExternalTransformAdaptation {
61 type Stats = Stats<StatsDims, <StepSizeStrategy as SamplerStats<M>>::Stats>;
62 type StatsOptions = ();
63
64 fn extract_stats(&self, math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
65 Stats {
66 tuning: self.tuning,
67 step_size: { self.step_size.extract_stats(math, ()) },
68 _phantom: std::marker::PhantomData,
69 }
70 }
71}
72
73pub struct DrawCollector<M: Math> {
74 draws: Vec<M::Vector>,
75 grads: Vec<M::Vector>,
76 logps: Vec<f64>,
77 collect_orbit: bool,
78 max_energy_error: f64,
79}
80
81impl<M: Math> DrawCollector<M> {
82 fn new(_math: &mut M, collect_orbit: bool, max_energy_error: f64) -> Self {
83 Self {
84 draws: vec![],
85 grads: vec![],
86 logps: vec![],
87 collect_orbit,
88 max_energy_error,
89 }
90 }
91}
92
93impl<M: Math, P: Point<M>> Collector<M, P> for DrawCollector<M> {
94 fn register_leapfrog(
95 &mut self,
96 math: &mut M,
97 _start: &State<M, P>,
98 end: &State<M, P>,
99 divergence_info: Option<&crate::DivergenceInfo>,
100 ) {
101 if divergence_info.is_some() {
102 return;
103 }
104
105 if self.collect_orbit {
106 let point = end.point();
107 let energy_error = point.energy_error();
108 if !energy_error.is_finite() {
109 return;
110 }
111
112 if energy_error > self.max_energy_error {
113 return;
114 }
115
116 if !math.array_all_finite(point.position()) {
117 return;
118 }
119 if !math.array_all_finite(point.gradient()) {
120 return;
121 }
122
123 self.draws.push(math.copy_array(point.position()));
124 self.grads.push(math.copy_array(point.gradient()));
125 self.logps.push(point.logp());
126 }
127 }
128
129 fn register_draw(&mut self, math: &mut M, state: &State<M, P>, _info: &SampleInfo) {
130 if !self.collect_orbit {
131 let point = state.point();
132 let energy_error = point.energy_error();
133 if !energy_error.is_finite() {
134 return;
135 }
136
137 if energy_error > self.max_energy_error {
138 return;
139 }
140
141 if !math.array_all_finite(point.position()) {
142 return;
143 }
144 if !math.array_all_finite(point.gradient()) {
145 return;
146 }
147
148 self.draws.push(math.copy_array(point.position()));
149 self.grads.push(math.copy_array(point.gradient()));
150 self.logps.push(point.logp());
151 }
152 }
153}
154
155impl<M: Math> AdaptStrategy<M> for ExternalTransformAdaptation {
156 type Hamiltonian = TransformedHamiltonian<M, ExternalTransformation<M>>;
157
158 type Collector =
159 CombinedCollector<M, TransformedPoint<M>, AcceptanceRateCollector, DrawCollector<M>>;
160
161 type Options = FlowSettings;
162
163 fn new(_math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self {
164 let step_size = StepSizeStrategy::new(options.step_size_settings);
165 let final_window_size =
166 ((num_tune as f64) * (1f64 - options.step_size_window)).floor() as u64;
167 Self {
168 step_size,
169 options,
170 num_tune,
171 final_window_size,
172 tuning: true,
173 chain,
174 }
175 }
176
177 fn init<R: rand::Rng + ?Sized>(
178 &mut self,
179 math: &mut M,
180 options: &mut NutsOptions,
181 hamiltonian: &mut Self::Hamiltonian,
182 position: &[f64],
183 rng: &mut R,
184 ) -> Result<(), NutsError> {
185 hamiltonian.init_transformation(rng, math, position, self.chain)?;
186 self.step_size
187 .init(math, options, hamiltonian, position, rng)?;
188 Ok(())
189 }
190
191 fn adapt<R: rand::Rng + ?Sized>(
192 &mut self,
193 math: &mut M,
194 _options: &mut NutsOptions,
195 hamiltonian: &mut Self::Hamiltonian,
196 draw: u64,
197 collector: &Self::Collector,
198 _state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
199 rng: &mut R,
200 ) -> Result<(), NutsError> {
201 self.step_size.update(&collector.collector1);
202
203 if draw >= self.num_tune {
204 self.step_size.update_stepsize(rng, hamiltonian, true);
206 self.tuning = false;
207 return Ok(());
208 }
209
210 if draw < self.final_window_size {
211 if draw < 100 {
212 if (draw > 0) && draw.is_multiple_of(10) {
213 hamiltonian.update_params(
214 math,
215 rng,
216 collector.collector2.draws.iter(),
217 collector.collector2.grads.iter(),
218 collector.collector2.logps.iter(),
219 )?;
220 }
221 } else if (draw > 0) && draw.is_multiple_of(self.options.transform_update_freq) {
222 hamiltonian.update_params(
223 math,
224 rng,
225 collector.collector2.draws.iter(),
226 collector.collector2.grads.iter(),
227 collector.collector2.logps.iter(),
228 )?;
229 }
230 self.step_size.update_estimator_early();
231 self.step_size.update_stepsize(rng, hamiltonian, false);
232 return Ok(());
233 }
234
235 self.step_size.update_estimator_late();
236 let is_last = draw == self.num_tune - 1;
237 self.step_size.update_stepsize(rng, hamiltonian, is_last);
238 Ok(())
239 }
240
241 fn new_collector(&self, math: &mut M) -> Self::Collector {
242 Self::Collector::new(
243 self.step_size.new_collector(),
244 DrawCollector::new(
245 math,
246 self.options.use_orbit_for_training,
247 self.options.transform_train_max_energy_error,
248 ),
249 )
250 }
251
252 fn is_tuning(&self) -> bool {
253 self.tuning
254 }
255
256 fn last_num_steps(&self) -> u64 {
257 self.step_size.last_n_steps
258 }
259}