nuts_rs/
transform_adapt_strategy.rs

1use nuts_derive::Storable;
2use nuts_storable::{HasDims, Storable};
3use serde::Serialize;
4
5use crate::adapt_strategy::CombinedCollector;
6use crate::chain::AdaptStrategy;
7use crate::hamiltonian::{Hamiltonian, Point};
8use crate::nuts::{Collector, NutsOptions, SampleInfo};
9use crate::sampler_stats::{SamplerStats, StatsDims};
10use crate::state::State;
11use crate::stepsize::AcceptanceRateCollector;
12use crate::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy};
13use crate::transformed_hamiltonian::TransformedHamiltonian;
14use crate::{Math, NutsError};
15
16#[derive(Clone, Copy, Debug, Serialize)]
17pub struct TransformedSettings {
18    pub step_size_window: f64,
19    pub transform_update_freq: u64,
20    pub use_orbit_for_training: bool,
21    pub step_size_settings: StepSizeSettings,
22    pub transform_train_max_energy_error: f64,
23}
24
25impl Default for TransformedSettings {
26    fn default() -> Self {
27        Self {
28            step_size_window: 0.07f64,
29            transform_update_freq: 128,
30            use_orbit_for_training: false,
31            transform_train_max_energy_error: 20f64,
32            step_size_settings: Default::default(),
33        }
34    }
35}
36
37pub struct TransformAdaptation {
38    step_size: StepSizeStrategy,
39    options: TransformedSettings,
40    num_tune: u64,
41    final_window_size: u64,
42    tuning: bool,
43    chain: u64,
44}
45
46#[derive(Debug, Storable)]
47pub struct Stats<P: HasDims, S: Storable<P>> {
48    tuning: bool,
49    #[storable(flatten)]
50    pub step_size: S,
51    #[storable(ignore)]
52    _phantom: std::marker::PhantomData<fn() -> P>,
53}
54
55impl<M: Math> SamplerStats<M> for TransformAdaptation {
56    type Stats = Stats<StatsDims, <StepSizeStrategy as SamplerStats<M>>::Stats>;
57    type StatsOptions = ();
58
59    fn extract_stats(&self, math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
60        Stats {
61            tuning: self.tuning,
62            step_size: { self.step_size.extract_stats(math, ()) },
63            _phantom: std::marker::PhantomData,
64        }
65    }
66}
67
68pub struct DrawCollector<M: Math> {
69    draws: Vec<M::Vector>,
70    grads: Vec<M::Vector>,
71    logps: Vec<f64>,
72    collect_orbit: bool,
73    max_energy_error: f64,
74}
75
76impl<M: Math> DrawCollector<M> {
77    fn new(_math: &mut M, collect_orbit: bool, max_energy_error: f64) -> Self {
78        Self {
79            draws: vec![],
80            grads: vec![],
81            logps: vec![],
82            collect_orbit,
83            max_energy_error,
84        }
85    }
86}
87
88impl<M: Math, P: Point<M>> Collector<M, P> for DrawCollector<M> {
89    fn register_leapfrog(
90        &mut self,
91        math: &mut M,
92        _start: &State<M, P>,
93        end: &State<M, P>,
94        divergence_info: Option<&crate::DivergenceInfo>,
95    ) {
96        if divergence_info.is_some() {
97            return;
98        }
99
100        if self.collect_orbit {
101            let point = end.point();
102            let energy_error = point.energy_error();
103            if !energy_error.is_finite() {
104                return;
105            }
106
107            if energy_error > self.max_energy_error {
108                return;
109            }
110
111            if !math.array_all_finite(point.position()) {
112                return;
113            }
114            if !math.array_all_finite(point.gradient()) {
115                return;
116            }
117
118            self.draws.push(math.copy_array(point.position()));
119            self.grads.push(math.copy_array(point.gradient()));
120            self.logps.push(point.logp());
121        }
122    }
123
124    fn register_draw(&mut self, math: &mut M, state: &State<M, P>, _info: &SampleInfo) {
125        if !self.collect_orbit {
126            let point = state.point();
127            let energy_error = point.energy_error();
128            if !energy_error.is_finite() {
129                return;
130            }
131
132            if energy_error > self.max_energy_error {
133                return;
134            }
135
136            if !math.array_all_finite(point.position()) {
137                return;
138            }
139            if !math.array_all_finite(point.gradient()) {
140                return;
141            }
142
143            self.draws.push(math.copy_array(point.position()));
144            self.grads.push(math.copy_array(point.gradient()));
145            self.logps.push(point.logp());
146        }
147    }
148}
149
150impl<M: Math> AdaptStrategy<M> for TransformAdaptation {
151    type Hamiltonian = TransformedHamiltonian<M>;
152
153    type Collector = CombinedCollector<
154        M,
155        <Self::Hamiltonian as Hamiltonian<M>>::Point,
156        AcceptanceRateCollector,
157        DrawCollector<M>,
158    >;
159
160    type Options = TransformedSettings;
161
162    fn new(_math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self {
163        let step_size = StepSizeStrategy::new(options.step_size_settings);
164        let final_window_size =
165            ((num_tune as f64) * (1f64 - options.step_size_window)).floor() as u64;
166        Self {
167            step_size,
168            options,
169            num_tune,
170            final_window_size,
171            tuning: true,
172            chain,
173        }
174    }
175
176    fn init<R: rand::Rng + ?Sized>(
177        &mut self,
178        math: &mut M,
179        options: &mut NutsOptions,
180        hamiltonian: &mut Self::Hamiltonian,
181        position: &[f64],
182        rng: &mut R,
183    ) -> Result<(), NutsError> {
184        hamiltonian.init_transformation(rng, math, position, self.chain)?;
185        self.step_size
186            .init(math, options, hamiltonian, position, rng)?;
187        Ok(())
188    }
189
190    fn adapt<R: rand::Rng + ?Sized>(
191        &mut self,
192        math: &mut M,
193        _options: &mut NutsOptions,
194        hamiltonian: &mut Self::Hamiltonian,
195        draw: u64,
196        collector: &Self::Collector,
197        _state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
198        rng: &mut R,
199    ) -> Result<(), NutsError> {
200        self.step_size.update(&collector.collector1);
201
202        if draw >= self.num_tune {
203            // Needed for step size jitter
204            self.step_size.update_stepsize(rng, hamiltonian, true);
205            self.tuning = false;
206            return Ok(());
207        }
208
209        if draw < self.final_window_size {
210            if draw < 100 {
211                if (draw > 0) && draw.is_multiple_of(10) {
212                    hamiltonian.update_params(
213                        math,
214                        rng,
215                        collector.collector2.draws.iter(),
216                        collector.collector2.grads.iter(),
217                        collector.collector2.logps.iter(),
218                    )?;
219                }
220            } else if (draw > 0) && draw.is_multiple_of(self.options.transform_update_freq) {
221                hamiltonian.update_params(
222                    math,
223                    rng,
224                    collector.collector2.draws.iter(),
225                    collector.collector2.grads.iter(),
226                    collector.collector2.logps.iter(),
227                )?;
228            }
229            self.step_size.update_estimator_early();
230            self.step_size.update_stepsize(rng, hamiltonian, false);
231            return Ok(());
232        }
233
234        self.step_size.update_estimator_late();
235        let is_last = draw == self.num_tune - 1;
236        self.step_size.update_stepsize(rng, hamiltonian, is_last);
237        Ok(())
238    }
239
240    fn new_collector(&self, math: &mut M) -> Self::Collector {
241        Self::Collector::new(
242            self.step_size.new_collector(),
243            DrawCollector::new(
244                math,
245                self.options.use_orbit_for_training,
246                self.options.transform_train_max_energy_error,
247            ),
248        )
249    }
250
251    fn is_tuning(&self) -> bool {
252        self.tuning
253    }
254
255    fn last_num_steps(&self) -> u64 {
256        self.step_size.last_n_steps
257    }
258}