Skip to main content

nuts_rs/
external_adapt_strategy.rs

//! Adaptation strategy for samplers whose coordinate transformation is learned from data
//! rather than computed analytically.

3use nuts_derive::Storable;
4use nuts_storable::{HasDims, Storable};
5use serde::{Deserialize, Serialize};
6
7use crate::adapt_strategy::CombinedCollector;
8use crate::chain::AdaptStrategy;
9use crate::dynamics::{Hamiltonian, Point, State, TransformedHamiltonian, TransformedPoint};
10use crate::nuts::{Collector, NutsOptions, SampleInfo};
11use crate::sampler_stats::{SamplerStats, StatsDims};
12use crate::stepsize::AcceptanceRateCollector;
13use crate::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy};
14use crate::transform::ExternalTransformation;
15use crate::{Math, NutsError};
16
17#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
18pub struct FlowSettings {
19    pub step_size_window: f64,
20    pub transform_update_freq: u64,
21    pub use_orbit_for_training: bool,
22    pub step_size_settings: StepSizeSettings,
23    pub transform_train_max_energy_error: f64,
24}
25
26/// Backwards-compatible alias for [`FlowSettings`].
27#[deprecated(since = "0.0.0", note = "Use FlowSettings instead")]
28pub type TransformedSettings = FlowSettings;
29
30impl Default for FlowSettings {
31    fn default() -> Self {
32        Self {
33            step_size_window: 0.07f64,
34            transform_update_freq: 128,
35            use_orbit_for_training: false,
36            transform_train_max_energy_error: 20f64,
37            step_size_settings: Default::default(),
38        }
39    }
40}
41
42pub struct ExternalTransformAdaptation {
43    step_size: StepSizeStrategy,
44    options: FlowSettings,
45    num_tune: u64,
46    final_window_size: u64,
47    tuning: bool,
48    chain: u64,
49}
50
51#[derive(Debug, Storable)]
52pub struct Stats<P: HasDims, S: Storable<P>> {
53    tuning: bool,
54    #[storable(flatten)]
55    pub step_size: S,
56    #[storable(ignore)]
57    _phantom: std::marker::PhantomData<fn() -> P>,
58}
59
60impl<M: Math> SamplerStats<M> for ExternalTransformAdaptation {
61    type Stats = Stats<StatsDims, <StepSizeStrategy as SamplerStats<M>>::Stats>;
62    type StatsOptions = ();
63
64    fn extract_stats(&self, math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
65        Stats {
66            tuning: self.tuning,
67            step_size: { self.step_size.extract_stats(math, ()) },
68            _phantom: std::marker::PhantomData,
69        }
70    }
71}
72
73pub struct DrawCollector<M: Math> {
74    draws: Vec<M::Vector>,
75    grads: Vec<M::Vector>,
76    logps: Vec<f64>,
77    collect_orbit: bool,
78    max_energy_error: f64,
79}
80
81impl<M: Math> DrawCollector<M> {
82    fn new(_math: &mut M, collect_orbit: bool, max_energy_error: f64) -> Self {
83        Self {
84            draws: vec![],
85            grads: vec![],
86            logps: vec![],
87            collect_orbit,
88            max_energy_error,
89        }
90    }
91}
92
93impl<M: Math, P: Point<M>> Collector<M, P> for DrawCollector<M> {
94    fn register_leapfrog(
95        &mut self,
96        math: &mut M,
97        _start: &State<M, P>,
98        end: &State<M, P>,
99        divergence_info: Option<&crate::DivergenceInfo>,
100    ) {
101        if divergence_info.is_some() {
102            return;
103        }
104
105        if self.collect_orbit {
106            let point = end.point();
107            let energy_error = point.energy_error();
108            if !energy_error.is_finite() {
109                return;
110            }
111
112            if energy_error > self.max_energy_error {
113                return;
114            }
115
116            if !math.array_all_finite(point.position()) {
117                return;
118            }
119            if !math.array_all_finite(point.gradient()) {
120                return;
121            }
122
123            self.draws.push(math.copy_array(point.position()));
124            self.grads.push(math.copy_array(point.gradient()));
125            self.logps.push(point.logp());
126        }
127    }
128
129    fn register_draw(&mut self, math: &mut M, state: &State<M, P>, _info: &SampleInfo) {
130        if !self.collect_orbit {
131            let point = state.point();
132            let energy_error = point.energy_error();
133            if !energy_error.is_finite() {
134                return;
135            }
136
137            if energy_error > self.max_energy_error {
138                return;
139            }
140
141            if !math.array_all_finite(point.position()) {
142                return;
143            }
144            if !math.array_all_finite(point.gradient()) {
145                return;
146            }
147
148            self.draws.push(math.copy_array(point.position()));
149            self.grads.push(math.copy_array(point.gradient()));
150            self.logps.push(point.logp());
151        }
152    }
153}
154
155impl<M: Math> AdaptStrategy<M> for ExternalTransformAdaptation {
156    type Hamiltonian = TransformedHamiltonian<M, ExternalTransformation<M>>;
157
158    type Collector =
159        CombinedCollector<M, TransformedPoint<M>, AcceptanceRateCollector, DrawCollector<M>>;
160
161    type Options = FlowSettings;
162
163    fn new(_math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self {
164        let step_size = StepSizeStrategy::new(options.step_size_settings);
165        let final_window_size =
166            ((num_tune as f64) * (1f64 - options.step_size_window)).floor() as u64;
167        Self {
168            step_size,
169            options,
170            num_tune,
171            final_window_size,
172            tuning: true,
173            chain,
174        }
175    }
176
177    fn init<R: rand::Rng + ?Sized>(
178        &mut self,
179        math: &mut M,
180        options: &mut NutsOptions,
181        hamiltonian: &mut Self::Hamiltonian,
182        position: &[f64],
183        rng: &mut R,
184    ) -> Result<(), NutsError> {
185        hamiltonian.init_transformation(rng, math, position, self.chain)?;
186        self.step_size
187            .init(math, options, hamiltonian, position, rng)?;
188        Ok(())
189    }
190
191    fn adapt<R: rand::Rng + ?Sized>(
192        &mut self,
193        math: &mut M,
194        _options: &mut NutsOptions,
195        hamiltonian: &mut Self::Hamiltonian,
196        draw: u64,
197        collector: &Self::Collector,
198        _state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
199        rng: &mut R,
200    ) -> Result<(), NutsError> {
201        self.step_size.update(&collector.collector1);
202
203        if draw >= self.num_tune {
204            // Needed for step size jitter
205            self.step_size.update_stepsize(rng, hamiltonian, true);
206            self.tuning = false;
207            return Ok(());
208        }
209
210        if draw < self.final_window_size {
211            if draw < 100 {
212                if (draw > 0) && draw.is_multiple_of(10) {
213                    hamiltonian.update_params(
214                        math,
215                        rng,
216                        collector.collector2.draws.iter(),
217                        collector.collector2.grads.iter(),
218                        collector.collector2.logps.iter(),
219                    )?;
220                }
221            } else if (draw > 0) && draw.is_multiple_of(self.options.transform_update_freq) {
222                hamiltonian.update_params(
223                    math,
224                    rng,
225                    collector.collector2.draws.iter(),
226                    collector.collector2.grads.iter(),
227                    collector.collector2.logps.iter(),
228                )?;
229            }
230            self.step_size.update_estimator_early();
231            self.step_size.update_stepsize(rng, hamiltonian, false);
232            return Ok(());
233        }
234
235        self.step_size.update_estimator_late();
236        let is_last = draw == self.num_tune - 1;
237        self.step_size.update_stepsize(rng, hamiltonian, is_last);
238        Ok(())
239    }
240
241    fn new_collector(&self, math: &mut M) -> Self::Collector {
242        Self::Collector::new(
243            self.step_size.new_collector(),
244            DrawCollector::new(
245                math,
246                self.options.use_orbit_for_training,
247                self.options.transform_train_max_energy_error,
248            ),
249        )
250    }
251
252    fn is_tuning(&self) -> bool {
253        self.tuning
254    }
255
256    fn last_num_steps(&self) -> u64 {
257        self.step_size.last_n_steps
258    }
259}