1use std::{
15 cell::{Ref, RefCell},
16 fmt::Debug,
17 marker::PhantomData,
18 ops::DerefMut,
19};
20
21use anyhow::{Result, bail};
22use nuts_derive::Storable;
23use nuts_storable::{HasDims, Storable};
24use rand_distr::num_traits::ToPrimitive;
25use serde::{Deserialize, Serialize};
26
27use crate::{
28 Math, NutsError,
29 chain::{AdaptStrategy, Chain, StatOptions},
30 dynamics::{
31 Direction, DivergenceInfo, DivergenceStats, Hamiltonian, KineticEnergyKind, Point, State,
32 TransformedHamiltonian, TransformedPoint,
33 },
34 nuts::{Collector, NutsOptions},
35 sampler::Progress,
36 sampler_stats::{SamplerStats, StatsDims},
37 transform::Transformation,
38};
39
/// Selects which dynamics the MCLMC sampler integrates for a chain.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum MclmcTrajectoryKind {
    /// Use a microcanonical kinetic energy for the whole run (the default).
    #[default]
    Microcanonical,

    /// Use a Euclidean kinetic energy for the whole run.
    Euclidean,

    /// Start with Euclidean dynamics, then switch the kinetic energy to
    /// microcanonical once the chain reaches its configured switch draw
    /// (see `MclmcChain::switch_draw`).
    EuclideanEarlyThenMicrocanonical,
}
71
/// Summary of a single MCLMC transition, kept on the chain so that
/// statistics can be extracted after the draw completes.
#[derive(Debug, Clone)]
pub struct MclmcInfo {
    /// Energy difference over the trajectory. On divergence this is the final
    /// minus the initial energy; otherwise it is the point's reported energy
    /// error (see `mclmc_kernel`).
    pub energy_change: f64,
    /// Whether the trajectory ended in a divergence.
    pub diverging: bool,
    /// Details about the divergence, if one occurred.
    pub divergence_info: Option<DivergenceInfo>,
    /// Number of leapfrog steps actually taken (including halved retry steps).
    pub num_steps: u64,
    /// Integrated trajectory time divided by the number of steps taken.
    pub average_step_size: f64,
}
87
/// Per-draw statistics emitted by `MclmcChain`, flattening the statistics of
/// the Hamiltonian (`H`), the adaptation strategy (`A`) and the current
/// point (`Pt`) into a single storable record.
#[derive(Debug, Storable)]
pub struct MclmcStats<P: HasDims, H: Storable<P>, A: Storable<P>, Pt: Storable<P>> {
    /// Index of the chain that produced this draw.
    pub chain: u64,
    /// Zero-based index of the draw within the chain.
    pub draw: u64,
    /// Number of integration steps taken for this draw.
    pub num_steps: u64,
    /// Energy change over the trajectory (see `MclmcInfo::energy_change`).
    pub energy_change: f64,
    /// Log importance weight of the draw. Currently set to the same value as
    /// `energy_change` in `extract_stats`.
    pub log_weight: f64,
    /// Whether the sampler was still tuning when this draw was produced.
    pub tuning: bool,
    /// Hamiltonian statistics, flattened into this record.
    #[storable(flatten)]
    pub hamiltonian: H,
    /// Adaptation-strategy statistics, flattened into this record.
    #[storable(flatten)]
    pub adapt: A,
    /// Statistics of the current point, flattened into this record.
    #[storable(flatten)]
    pub point: Pt,
    /// Average integration step size over the trajectory.
    pub average_step_size: f64,
    /// Divergence information, flattened into this record.
    #[storable(flatten)]
    pub divergence: DivergenceStats,
    /// Marker tying the dims type parameter `P` into the struct without
    /// storing a value of it; excluded from storage.
    #[storable(ignore)]
    _phantom: PhantomData<fn() -> P>,
}
120
/// A sampler chain running Microcanonical Langevin Monte Carlo (MCLMC)
/// dynamics on a transformed space.
pub struct MclmcChain<M, R, A, T>
where
    M: Math,
    R: rand::Rng,
    T: Transformation<M>,
    A: AdaptStrategy<M, Hamiltonian = TransformedHamiltonian<M, T>>,
{
    /// Hamiltonian (step size, kinetic energy, transformation) used for integration.
    hamiltonian: TransformedHamiltonian<M, T>,
    /// Collector fed to the integrator; replaced after each adaptation step.
    collector: A::Collector,
    /// Step-size / mass-matrix adaptation strategy.
    adapt: A,
    /// Current state of the chain.
    state: State<M, TransformedPoint<M>>,
    /// Random number generator used for momentum noise and refreshes.
    rng: R,
    /// Index of this chain.
    chain: u64,
    /// Number of completed draws.
    draw_count: u64,
    /// Multiplier on the momentum decoherence length that determines how
    /// many integration steps make up one draw (see `mclmc_kernel`).
    subsample_frequency: f64,
    /// If true, diverging steps are retried with a halved step-size factor.
    dynamic_step_size: bool,
    /// Which dynamics to integrate (see `MclmcTrajectoryKind`).
    trajectory_kind: MclmcTrajectoryKind,
    /// Draw index at which `EuclideanEarlyThenMicrocanonical` switches
    /// to microcanonical dynamics.
    switch_draw: u64,
    /// Energy-error budget for a trajectory; `mclmc_kernel` divides it across
    /// the base steps to obtain a per-step threshold.
    max_energy_error: f64,
    /// Options handed to the adaptation strategy.
    nuts_options: NutsOptions,
    /// Math backend; wrapped in a `RefCell` so `math()` can hand out
    /// shared borrows while `&mut self` methods use `get_mut`/`borrow_mut`.
    math: RefCell<M>,
    /// Options controlling which statistics are extracted.
    stats_options: StatOptions<M, A>,
    /// Info about the most recent draw, consumed by `extract_stats`.
    last_info: Option<MclmcInfo>,
    /// Scratch buffer holding the pre-step velocity so a diverging step
    /// can be rolled back.
    tmp_velocity: M::Vector,
}
167
168impl<M, R, A, T> MclmcChain<M, R, A, T>
169where
170 M: Math,
171 R: rand::Rng,
172 T: Transformation<M>,
173 A: AdaptStrategy<M, Hamiltonian = TransformedHamiltonian<M, T>>,
174{
175 pub fn new(
176 mut math: M,
177 mut hamiltonian: TransformedHamiltonian<M, T>,
178 adapt: A,
179 rng: R,
180 chain: u64,
181 subsample_frequency: f64,
182 dynamic_step_size: bool,
183 trajectory_kind: MclmcTrajectoryKind,
184 switch_draw: u64,
185 max_energy_error: f64,
186 stats_options: StatOptions<M, A>,
187 ) -> Self {
188 let state = hamiltonian.pool().new_state(&mut math);
189 let collector = adapt.new_collector(&mut math);
190 let tmp_velocity = math.new_array();
191 Self {
192 hamiltonian,
193 collector,
194 adapt,
195 state,
196 rng,
197 chain,
198 draw_count: 0,
199 subsample_frequency,
200 dynamic_step_size,
201 trajectory_kind,
202 switch_draw,
203 nuts_options: NutsOptions::default(),
204 math: math.into(),
205 stats_options,
206 last_info: None,
207 tmp_velocity,
208 max_energy_error,
209 }
210 }
211
212 fn mclmc_kernel(
213 &mut self,
214 resample_velocity: bool,
215 ) -> Result<(State<M, TransformedPoint<M>>, MclmcInfo)> {
216 let math = self.math.get_mut();
217
218 let base_step_size = self.hamiltonian.step_size();
219 let num_base_steps: u64 = self
220 .hamiltonian
221 .momentum_decoherence_length()
222 .map(|length| {
223 let num_steps = (self.subsample_frequency * length / base_step_size)
224 .round()
225 .max(1.0)
226 .min(1e6);
227 if !num_steps.is_finite() {
228 bail!("Invalid number of integration steps");
229 }
230 Ok(num_steps as u64)
231 })
232 .unwrap_or(Ok(1))?;
233
234 let max_halvings: u64 = if self.dynamic_step_size { 10 } else { 0 };
243
244 use crate::dynamics::LeapfrogResult;
245
246 let mut current = self.hamiltonian.copy_state(math, &self.state);
248
249 self.hamiltonian.initialize_trajectory(
250 math,
251 &mut current,
252 resample_velocity,
253 &mut self.rng,
254 )?;
255
256 let ones = {
257 let mut ones = math.new_array();
258 math.fill_array(&mut ones, 1.0);
259 ones
260 };
261 let mut momentum_noise = math.new_array();
262 math.array_gaussian(&mut self.rng, &mut momentum_noise, &ones);
263
264 let draw_start_energy = current.point().energy();
267
268 let mut divergence_info: Option<DivergenceInfo> = None;
269 let mut steps_taken = 0u64;
270
271 let mut factor = 1.0_f64;
273
274 let mut remaining_stack: Vec<u64> = Vec::with_capacity(max_halvings.try_into().unwrap());
275
276 let mut remaining = num_base_steps;
277 let mut time = 0.0;
278
279 while remaining > 0 {
280 math.copy_into(¤t.point().velocity, &mut self.tmp_velocity);
283
284 self.hamiltonian.partial_momentum_refresh(
285 math,
286 &mut current,
287 &momentum_noise,
288 &mut self.rng,
289 factor,
290 )?;
291
292 let step_baseline = current.point().energy();
299 match self.hamiltonian.leapfrog(
301 math,
302 ¤t,
303 Direction::Forward,
304 factor,
305 step_baseline,
306 self.max_energy_error * factor / num_base_steps.to_f64().unwrap(),
308 &mut self.collector,
309 ) {
310 LeapfrogResult::Ok(mut next) => {
311 math.array_gaussian(&mut self.rng, &mut momentum_noise, &ones);
312 self.hamiltonian.partial_momentum_refresh(
313 math,
314 &mut next,
315 &momentum_noise,
316 &mut self.rng,
317 factor,
318 )?;
319 math.array_gaussian(&mut self.rng, &mut momentum_noise, &ones);
321 current = next;
322 steps_taken += 1;
323 remaining -= 1;
324 time += factor * base_step_size;
325
326 while remaining == 0 {
327 if let Some(prev_remaining) = remaining_stack.pop() {
328 remaining = prev_remaining - 1;
329 factor *= 2.0;
330 } else {
331 break;
332 }
333 }
334 }
335 LeapfrogResult::Divergence(info) => {
336 if remaining_stack.len() >= max_halvings.try_into().unwrap() {
337 divergence_info = Some(info);
339 break;
340 }
341 factor *= 0.5;
344 remaining_stack.push(remaining);
345 remaining = 2;
346
347 math.copy_into(
349 &self.tmp_velocity,
350 &mut current.try_point_mut().unwrap().velocity,
351 );
352 }
355 LeapfrogResult::Err(e) => {
356 return Err(NutsError::LogpFailure(e.into()).into());
357 }
358 }
359 }
360
361 if divergence_info.is_some() {
362 let mut next_state = self.hamiltonian.copy_state(math, &self.state);
366 self.hamiltonian
367 .initialize_trajectory(math, &mut next_state, true, &mut self.rng)?;
368 let energy_change = current.point().energy() - draw_start_energy;
369 let info = MclmcInfo {
370 energy_change,
371 diverging: true,
372 divergence_info: divergence_info.clone(),
374 num_steps: steps_taken,
375 average_step_size: time / steps_taken.to_f64().unwrap(),
376 };
377 let sample_info = crate::nuts::SampleInfo {
378 depth: steps_taken,
379 divergence_info,
380 reached_maxdepth: false,
381 };
382 self.collector.register_draw(math, ¤t, &sample_info);
383 return Ok((next_state, info));
384 }
385
386 assert!(steps_taken >= num_base_steps);
387
388 let sample_info = crate::nuts::SampleInfo {
390 depth: steps_taken,
391 divergence_info: None,
392 reached_maxdepth: false,
393 };
394 self.collector.register_draw(math, ¤t, &sample_info);
395
396 let energy_change = current.point().energy_error();
399
400 let info = MclmcInfo {
401 energy_change,
402 diverging: false,
403 divergence_info: None,
404 num_steps: steps_taken,
405 average_step_size: time / steps_taken.to_f64().unwrap(),
406 };
407
408 Ok((current, info))
409 }
410}
411
412impl<M, R, A, T> SamplerStats<M> for MclmcChain<M, R, A, T>
415where
416 M: Math,
417 R: rand::Rng,
418 T: Transformation<M>,
419 A: AdaptStrategy<M, Hamiltonian = TransformedHamiltonian<M, T>>,
420{
421 type Stats = MclmcStats<
422 StatsDims,
423 <TransformedHamiltonian<M, T> as SamplerStats<M>>::Stats,
424 A::Stats,
425 <TransformedPoint<M> as SamplerStats<M>>::Stats,
426 >;
427 type StatsOptions = StatOptions<M, A>;
428
429 fn extract_stats(&self, math: &mut M, options: Self::StatsOptions) -> Self::Stats {
430 let info = self
431 .last_info
432 .as_ref()
433 .expect("Sampler has not started yet");
434 let hamiltonian_stats = self.hamiltonian.extract_stats(math, options.hamiltonian);
435 let adapt_stats = self.adapt.extract_stats(math, options.adapt);
436 let point_stats = self.state.point().extract_stats(math, options.point);
437 MclmcStats {
438 chain: self.chain,
439 draw: self.draw_count,
440 num_steps: info.num_steps,
441 energy_change: info.energy_change,
442 log_weight: info.energy_change,
443 tuning: self.adapt.is_tuning(),
444 hamiltonian: hamiltonian_stats,
445 adapt: adapt_stats,
446 point: point_stats,
447 average_step_size: info.average_step_size,
448 divergence: (
449 info.divergence_info.as_ref(),
450 options.divergence,
451 self.draw_count,
452 )
453 .into(),
454 _phantom: PhantomData,
455 }
456 }
457}
458
impl<M, R, A, T> Chain<M> for MclmcChain<M, R, A, T>
where
    M: Math,
    R: rand::Rng,
    T: Transformation<M>,
    A: AdaptStrategy<M, Hamiltonian = TransformedHamiltonian<M, T>>,
{
    type AdaptStrategy = A;

    /// Initialize the chain at `position`: run adaptation init, build the
    /// initial state, and resample the trajectory velocity.
    fn set_position(&mut self, position: &[f64]) -> Result<()> {
        let mut math_ = self.math.borrow_mut();
        let math = math_.deref_mut();
        self.adapt.init(
            math,
            &mut self.nuts_options,
            &mut self.hamiltonian,
            position,
            &mut self.rng,
        )?;
        self.state = self.hamiltonian.init_state(math, position)?;
        self.hamiltonian
            .initialize_trajectory(math, &mut self.state, true, &mut self.rng)?;
        Ok(())
    }

    /// Produce one draw: possibly switch the kinetic energy, run the MCLMC
    /// kernel, adapt, and advance the chain state.
    fn draw(&mut self) -> Result<(Box<[f64]>, Progress)> {
        // For the Euclidean-then-microcanonical schedule: at the switch draw,
        // change the kinetic energy once and resample the velocity so the new
        // dynamics starts from a valid momentum.
        let resample_velocity = if self.trajectory_kind
            == MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical
            && self.draw_count == self.switch_draw
            && self.hamiltonian.kinetic_energy_kind != KineticEnergyKind::Microcanonical
        {
            self.hamiltonian
                .set_kinetic_energy_kind(KineticEnergyKind::Microcanonical);
            true
        } else {
            false
        };

        let (state, info) = self.mclmc_kernel(resample_velocity)?;

        // Copy the new position out while holding the math borrow only
        // for the duration of the copy.
        let position: Box<[f64]> = {
            let mut math_ = self.math.borrow_mut();
            let math = math_.deref_mut();
            let mut pos = vec![0f64; math.dim()];
            state.write_position(math, &mut pos);
            pos.into()
        };

        let progress = Progress {
            draw: self.draw_count,
            chain: self.chain,
            diverging: info.diverging,
            tuning: self.adapt.is_tuning(),
            step_size: self.hamiltonian.step_size(),
            num_steps: info.num_steps,
        };

        {
            let mut math_ = self.math.borrow_mut();
            let math = math_.deref_mut();
            // Adapt using the statistics the collector gathered during this
            // draw, then start a fresh collector for the next one.
            self.adapt.adapt(
                math,
                &mut self.nuts_options,
                &mut self.hamiltonian,
                self.draw_count,
                &self.collector,
                &state,
                &mut self.rng,
            )?;
            self.collector = self.adapt.new_collector(math);
        }

        // Commit the transition only after adaptation succeeded.
        self.draw_count += 1;
        self.state = state;
        self.last_info = Some(info);
        Ok((position, progress))
    }

    /// Dimension of the sampled space.
    fn dim(&self) -> usize {
        self.math.borrow().dim()
    }

    /// Like `draw`, but also extract statistics and the expanded
    /// (untransformed) parameter vector for the new state.
    fn expanded_draw(&mut self) -> Result<(Box<[f64]>, M::ExpandedVector, Self::Stats, Progress)> {
        let (position, progress) = self.draw()?;
        let mut math_ = self.math.borrow_mut();
        let math = math_.deref_mut();
        // NOTE(review): `stats_options` is passed by value here, which
        // requires `StatOptions<M, A>: Copy` — confirm.
        let stats = self.extract_stats(math, self.stats_options);
        // Let the Hamiltonian update its stats options for the next draw.
        self.stats_options.hamiltonian = self
            .hamiltonian
            .update_stats_options(math, self.stats_options.hamiltonian);
        let expanded = math.expand_vector(&mut self.rng, self.state.point().position())?;
        Ok((position, expanded, stats, progress))
    }

    /// Shared borrow of the math backend.
    fn math(&self) -> Ref<'_, M> {
        self.math.borrow()
    }
}
570
#[cfg(test)]
mod tests {
    use rand::rng;

    use crate::{
        Chain, DiagMclmcSettings, MclmcSettings, math::CpuMath, math::test_logps::NormalLogp,
        sampler::Settings,
    };

    /// Dimension of the test posterior.
    const NDIM: usize = 10;

    /// Shared driver for the MCLMC smoke tests: run 500 draws of a chain with
    /// the given settings on an `NDIM`-dimensional normal posterior with mean
    /// 3.0, assert that no draw diverges, and check that the mean of the last
    /// draw lies in a loose band around the true mean.
    fn run_and_check_mean(settings: DiagMclmcSettings) {
        let func = NormalLogp::new(NDIM, 3.0);
        let math = CpuMath::new(func);

        let mut rng = rng();
        let mut chain = settings.new_chain(0, math, &mut rng);
        chain.set_position(&vec![0.0f64; NDIM]).unwrap();

        let mut last_pos = vec![0.0f64; NDIM];
        for _ in 0..500 {
            let (draw, progress) = chain.draw().unwrap();
            assert!(!progress.diverging, "unexpected divergence");
            last_pos.copy_from_slice(&draw);
        }

        // Very loose sanity check on a single posterior draw: the per-dim
        // mean should be within 3.0 of the true mean with overwhelming
        // probability.
        let mean: f64 = last_pos.iter().sum::<f64>() / NDIM as f64;
        assert!(
            (mean - 3.0).abs() < 3.0,
            "mean {mean} too far from expected 3.0"
        );
    }

    /// Default (microcanonical) trajectory kind.
    #[test]
    fn mclmc_draws_normal() {
        run_and_check_mean(DiagMclmcSettings {
            step_size: 0.5,
            momentum_decoherence_length: 3.0,
            num_tune: 200,
            num_draws: 500,
            ..MclmcSettings::default()
        });
    }

    /// Euclidean trajectory kind for the whole run.
    #[test]
    fn mclmc_euclidean_trajectory() {
        use crate::mclmc::MclmcTrajectoryKind;

        run_and_check_mean(DiagMclmcSettings {
            step_size: 0.3,
            momentum_decoherence_length: 3.0,
            num_tune: 200,
            num_draws: 500,
            trajectory_kind: MclmcTrajectoryKind::Euclidean,
            ..MclmcSettings::default()
        });
    }

    /// Euclidean dynamics early in the run, switching to microcanonical
    /// after 30% of the draws.
    #[test]
    fn mclmc_euclidean_early_then_microcanonical() {
        use crate::mclmc::MclmcTrajectoryKind;

        run_and_check_mean(DiagMclmcSettings {
            step_size: 0.5,
            momentum_decoherence_length: 3.0,
            num_tune: 200,
            num_draws: 500,
            trajectory_kind: MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical,
            trajectory_switch_fraction: 0.3,
            ..MclmcSettings::default()
        });
    }
}