1use anyhow::Result;
5use nuts_storable::{HasDims, Storable, Value};
6use rand::{Rng, SeedableRng, rngs::ChaCha8Rng};
7use serde::{Deserialize, Serialize, de::DeserializeOwned};
8use std::{collections::HashMap, fmt::Debug, time::Duration};
9
10#[cfg(feature = "parallel")]
11use anyhow::{Context, bail};
12#[cfg(feature = "parallel")]
13use itertools::Itertools;
14#[cfg(feature = "parallel")]
15use std::ops::Deref;
16
17#[cfg(feature = "parallel")]
18use rayon::{ScopeFifo, ThreadPoolBuilder};
19#[cfg(feature = "parallel")]
20use std::{
21 sync::{
22 Arc, Mutex,
23 mpsc::{
24 Receiver, RecvTimeoutError, Sender, SyncSender, TryRecvError, channel, sync_channel,
25 },
26 },
27 thread::{JoinHandle, spawn},
28 time::Instant,
29};
30
31use crate::{
32 DiagAdaptExpSettings, Math, StepSizeAdaptMethod,
33 adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy, GlobalStrategyStatsOptions},
34 chain::{AdaptStrategy, Chain, NutsChain, StatOptions},
35 dynamics::{KineticEnergyKind, TransformedHamiltonian, TransformedPointStatsOptions},
36 external_adapt_strategy::{ExternalTransformAdaptation, FlowSettings},
37 mclmc::MclmcTrajectoryKind,
38 nuts::NutsOptions,
39 sampler_stats::{SamplerStats, StatsDims},
40 transform::{
41 DiagAdaptStrategy, DiagMassMatrix, ExternalTransformation, LowRankMassMatrix,
42 LowRankMassMatrixStrategy, LowRankSettings,
43 },
44};
45
46#[cfg(feature = "parallel")]
47use crate::{model::Model, storage::{ChainStorage, StorageConfig, TraceStorage}};
48
49pub trait Settings:
51 private::Sealed + Clone + Copy + Default + Sync + Send + Serialize + DeserializeOwned + 'static
52{
53 type Chain<M: Math>: Chain<M>;
54
55 fn new_chain<M: Math, R: Rng + ?Sized>(
56 &self,
57 chain: u64,
58 math: M,
59 rng: &mut R,
60 ) -> Self::Chain<M>;
61
62 fn hint_num_tune(&self) -> usize;
63 fn hint_num_draws(&self) -> usize;
64 fn num_chains(&self) -> usize;
65 fn seed(&self) -> u64;
66 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions;
67 fn sampler_name(&self) -> &'static str;
68 fn adaptation_name(&self) -> &'static str;
69
70 fn stat_names<M: Math>(&self, math: &M) -> Vec<String> {
71 let dims = StatsDims::from(math);
72 <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::names(&dims)
73 .into_iter()
74 .map(String::from)
75 .collect()
76 }
77
78 fn data_names<M: Math>(&self, math: &M) -> Vec<String> {
79 <M::ExpandedVector as Storable<_>>::names(math)
80 .into_iter()
81 .map(String::from)
82 .collect()
83 }
84
85 fn stat_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
86 self.stat_names(math)
87 .into_iter()
88 .map(|name| (name.clone(), self.stat_type::<M>(math, &name)))
89 .collect()
90 }
91
92 fn stat_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
93 let dims = StatsDims::from(math);
94 <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::item_type(&dims, name)
95 }
96
97 fn data_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
98 self.data_names(math)
99 .into_iter()
100 .map(|name| (name.clone(), self.data_type(math, &name)))
101 .collect()
102 }
103 fn data_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
104 <M::ExpandedVector as Storable<_>>::item_type(math, name)
105 }
106
107 fn stat_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
108 self.stat_names(math)
109 .into_iter()
110 .map(|name| (name.clone(), self.stat_dims::<M>(math, &name)))
111 .collect()
112 }
113
114 fn stat_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
115 let dims = StatsDims::from(math);
116 <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::dims(&dims, name)
117 .into_iter()
118 .map(String::from)
119 .collect()
120 }
121
122 fn stat_dim_sizes<M: Math>(&self, math: &M) -> HashMap<String, u64> {
123 let dims = StatsDims::from(math);
124 dims.dim_sizes()
125 }
126
127 fn data_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
128 self.data_names(math)
129 .into_iter()
130 .map(|name| (name.clone(), self.data_dims(math, &name)))
131 .collect()
132 }
133
134 fn data_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
135 <M::ExpandedVector as Storable<_>>::dims(math, name)
136 .into_iter()
137 .map(String::from)
138 .collect()
139 }
140
141 fn stat_coords<M: Math>(&self, math: &M) -> HashMap<String, Value> {
142 let dims = StatsDims::from(math);
143 dims.coords()
144 }
145
146 fn stat_event_dims<M: Math>(&self, math: &M) -> Vec<(String, Option<String>)> {
147 let dims = StatsDims::from(math);
148 self.stat_names(math)
149 .into_iter()
150 .map(|name| {
151 let event_dim =
152 <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::event_dim(
153 &dims, &name,
154 )
155 .map(String::from);
156 (name, event_dim)
157 })
158 .collect()
159 }
160}
161
/// Per-draw summary produced by a chain together with each sample.
///
/// `ChainProgress::update` reads the `diverging`, `tuning`, `step_size` and `num_steps` fields
/// to maintain running per-chain progress.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Progress {
    /// Index of the draw. NOTE(review): not read in this file — confirm whether tuning draws
    /// are included in the count.
    pub draw: u64,
    /// Index of the chain that produced the draw.
    pub chain: u64,
    /// Whether the draw diverged; counted by `ChainProgress` only outside tuning.
    pub diverging: bool,
    /// Whether the draw was taken during tuning.
    pub tuning: bool,
    /// Step size reported for this draw.
    pub step_size: f64,
    /// Number of integration steps reported for this draw.
    pub num_steps: u64,
}
172
// Seal for `Settings`: the trait requires `private::Sealed`, and because this module is private,
// only the settings types listed here can implement `Settings`.
mod private {
    use super::{
        DiagMclmcSettings, DiagNutsSettings, FlowMclmcSettings, FlowNutsSettings,
        LowRankMclmcSettings, LowRankNutsSettings,
    };

    /// Marker supertrait of `Settings`; not nameable outside the parent module.
    pub trait Sealed {}

    impl Sealed for DiagNutsSettings {}

    impl Sealed for LowRankNutsSettings {}

    impl Sealed for FlowNutsSettings {}

    impl Sealed for DiagMclmcSettings {}

    impl Sealed for LowRankMclmcSettings {}

    impl Sealed for FlowMclmcSettings {}
}
193
/// Settings for the NUTS sampler, generic over the adaptation options `A`.
///
/// Concrete variants are `DiagNutsSettings`, `LowRankNutsSettings` and `FlowNutsSettings`. Their
/// `Default` impls are built with `default_nuts_settings`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct NutsSettings<A: Debug + Copy + Default + Serialize> {
    /// Number of tuning draws per chain. Passed to the adaptation strategy and, together with
    /// `num_draws`, determines how many draws a chain performs in total.
    pub num_tune: u64,
    /// Number of draws per chain after tuning (default 1000).
    pub num_draws: u64,
    /// Forwarded to `NutsOptions::maxdepth` (default 10); presumably the maximum tree depth.
    pub maxdepth: u64,
    /// Forwarded to `NutsOptions::mindepth` (default 0).
    pub mindepth: u64,
    /// Forwarded to `TransformedPointStatsOptions::store_gradient` (default `false`).
    pub store_gradient: bool,
    /// Forwarded to `TransformedPointStatsOptions::store_unconstrained` (default `false`).
    pub store_unconstrained: bool,
    /// Forwarded to `TransformedPointStatsOptions::store_transformed` (default `false`).
    pub store_transformed: bool,
    /// Forwarded to `NutsOptions::max_energy_error` (default 1000 for Euclidean variants, 20 for
    /// flow). NOTE(review): presumably the divergence threshold — confirm in `nuts.rs`.
    pub max_energy_error: f64,
    /// Forwarded to both `NutsOptions` and `DivergenceStatsOptions` (default `false`).
    pub store_divergences: bool,
    /// Adaptation settings; the concrete type selects the adaptation scheme.
    pub adapt_options: A,
    /// Forwarded to `NutsOptions::check_turning` (default `true`).
    pub check_turning: bool,
    /// Forwarded to `NutsOptions::target_integration_time` (default `None`).
    pub target_integration_time: Option<f64>,
    /// Kinetic energy handed to `TransformedHamiltonian::new` (default `Euclidean`).
    pub trajectory_kind: KineticEnergyKind,
    /// Number of chains to run.
    pub num_chains: usize,
    /// Base seed for the random number generators (default 0).
    pub seed: u64,
    /// Forwarded to `NutsOptions::extra_doublings` (default 0).
    pub extra_doublings: u64,
}
237
/// NUTS with a diagonal mass matrix (`DiagMassMatrix`).
pub type DiagNutsSettings = NutsSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;
/// Deprecated alias of `DiagNutsSettings`.
#[deprecated(since = "0.0.0", note = "Use DiagNutsSettings instead")]
pub type DiagGradNutsSettings = DiagNutsSettings;
/// NUTS with a low-rank mass matrix (`LowRankMassMatrix`).
pub type LowRankNutsSettings = NutsSettings<EuclideanAdaptOptions<LowRankSettings>>;
/// NUTS with an external transformation (`ExternalTransformation`).
pub type FlowNutsSettings = NutsSettings<FlowSettings>;
/// Deprecated alias of `FlowNutsSettings`.
#[deprecated(since = "0.0.0", note = "Use FlowNutsSettings instead")]
pub type TransformedNutsSettings = FlowNutsSettings;
247
/// Settings for the MCLMC sampler, generic over the adaptation options `A`.
///
/// Concrete variants are `DiagMclmcSettings`, `LowRankMclmcSettings` and `FlowMclmcSettings`.
/// Their `Default` impls are built with `default_mclmc_settings`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct MclmcSettings<A: Debug + Copy + Default + Serialize> {
    /// Step size. The Euclidean variants install it as `StepSizeAdaptMethod::Fixed` in the
    /// adaptation options (default 0.5).
    pub step_size: f64,
    /// Passed to `TransformedHamiltonian::set_momentum_decoherence_length` (default 3.0).
    pub momentum_decoherence_length: f64,
    /// Number of tuning draws per chain.
    pub num_tune: u64,
    /// Number of draws per chain after tuning (default 1000).
    pub num_draws: u64,
    /// Number of chains to run.
    pub num_chains: usize,
    /// Base seed for the random number generators (default 0).
    pub seed: u64,
    /// Forwarded to `MclmcChain::new`.
    pub max_energy_error: f64,
    /// Forwarded to `TransformedPointStatsOptions::store_unconstrained` (default `false`).
    pub store_unconstrained: bool,
    /// Forwarded to `TransformedPointStatsOptions::store_gradient` (default `false`).
    pub store_gradient: bool,
    /// Forwarded to `TransformedPointStatsOptions::store_transformed` (default `false`).
    pub store_transformed: bool,
    /// Forwarded to `DivergenceStatsOptions::store_divergences` (default `false`).
    pub store_divergences: bool,
    /// Adaptation settings; the concrete type selects the adaptation scheme.
    pub adapt_options: A,
    /// Forwarded to `MclmcChain::new` (default 1.0).
    pub subsample_frequency: f64,
    /// Forwarded to `MclmcChain::new` (default `true`).
    pub dynamic_step_size: bool,
    /// Selects the initial kinetic energy (`Microcanonical`, or `Euclidean` for both Euclidean
    /// variants) and is forwarded to `MclmcChain::new`.
    pub trajectory_kind: MclmcTrajectoryKind,
    /// Fraction of `num_tune` used to compute the switch draw passed to `MclmcChain::new`
    /// (default 0.3).
    pub trajectory_switch_fraction: f64,
}
316
/// MCLMC with a diagonal mass matrix.
pub type DiagMclmcSettings = MclmcSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;
/// MCLMC with a low-rank mass matrix.
pub type LowRankMclmcSettings = MclmcSettings<EuclideanAdaptOptions<LowRankSettings>>;
/// MCLMC with an external transformation.
pub type FlowMclmcSettings = MclmcSettings<FlowSettings>;
/// Deprecated alias of `FlowMclmcSettings`.
#[deprecated(since = "0.0.0", note = "Use FlowMclmcSettings instead")]
pub type TransformedMclmcSettings = FlowMclmcSettings;
335
/// Convert the `u64` setting `value` into a `usize`.
///
/// # Panics
/// Panics with a message naming `field` if `value` does not fit into a `usize`.
fn usize_hint(value: u64, field: &str) -> usize {
    match usize::try_from(value) {
        Ok(converted) => converted,
        Err(_) => panic!("{field} must be smaller than usize::MAX"),
    }
}
341
342fn default_mclmc_settings<A: Debug + Copy + Default + Serialize>(
343 adapt_options: A,
344 num_tune: u64,
345 num_chains: usize,
346 max_energy_error: f64,
347) -> MclmcSettings<A> {
348 MclmcSettings {
349 step_size: 0.5,
350 momentum_decoherence_length: 3.0,
351 num_tune,
352 num_draws: 1000,
353 num_chains,
354 seed: 0,
355 max_energy_error,
356 store_unconstrained: false,
357 store_gradient: false,
358 store_divergences: false,
359 store_transformed: false,
360 adapt_options,
361 subsample_frequency: 1.0,
362 dynamic_step_size: true,
363 trajectory_kind: MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical,
364 trajectory_switch_fraction: 0.3,
365 }
366}
367
368impl Default for DiagMclmcSettings {
369 fn default() -> Self {
370 let mut adapt_options = EuclideanAdaptOptions::default();
371 adapt_options.step_size_settings.adapt_options.method = StepSizeAdaptMethod::Fixed(0.5);
372 default_mclmc_settings(adapt_options, 400, 6, 1000.0)
373 }
374}
375
376impl Default for LowRankMclmcSettings {
377 fn default() -> Self {
378 let mut adapt_options = EuclideanAdaptOptions::default();
379 adapt_options.early_mass_matrix_switch_freq = 20;
380 adapt_options.step_size_settings.adapt_options.method = StepSizeAdaptMethod::Fixed(0.5);
381 default_mclmc_settings(adapt_options, 800, 6, 1000.0)
382 }
383}
384
385impl Default for FlowMclmcSettings {
386 fn default() -> Self {
387 default_mclmc_settings(FlowSettings::default(), 1500, 1, 20.0)
388 }
389}
390
/// MCLMC chain with the global adaptation strategy and a diagonal mass matrix.
type DiagMclmcChain<M> = crate::mclmc::MclmcChain<
    M,
    ChaCha8Rng,
    GlobalStrategy<M, DiagAdaptStrategy<M>>,
    DiagMassMatrix<M>,
>;
/// MCLMC chain with the global adaptation strategy and a low-rank mass matrix.
type LowRankMclmcChain<M> = crate::mclmc::MclmcChain<
    M,
    ChaCha8Rng,
    GlobalStrategy<M, LowRankMassMatrixStrategy>,
    LowRankMassMatrix<M>,
>;
403
404impl Settings for DiagMclmcSettings {
405 type Chain<M: Math> = DiagMclmcChain<M>;
406
407 fn new_chain<M: Math, R: Rng + ?Sized>(
408 &self,
409 chain: u64,
410 mut math: M,
411 rng: &mut R,
412 ) -> Self::Chain<M> {
413 use crate::dynamics::KineticEnergyKind;
414 use crate::mclmc::MclmcChain;
415 use crate::stepsize::StepSizeAdaptMethod;
416
417 let num_tune = self.num_tune;
418 let mut adapt_options = self.adapt_options;
419 adapt_options.step_size_settings.adapt_options.method =
420 StepSizeAdaptMethod::Fixed(self.step_size);
421 let strategy = GlobalStrategy::<M, DiagAdaptStrategy<M>>::new(
422 &mut math,
423 adapt_options,
424 num_tune,
425 chain,
426 );
427 let mass_matrix = DiagMassMatrix::new(
428 &mut math,
429 self.adapt_options.mass_matrix_options.store_mass_matrix,
430 );
431 let initial_kind = match self.trajectory_kind {
432 MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
433 MclmcTrajectoryKind::Euclidean
434 | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
435 };
436 let mut hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, initial_kind);
437 hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
438 let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
439 let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
440 let stats_options = self.stats_options::<M>();
441 MclmcChain::new(
442 math,
443 hamiltonian,
444 strategy,
445 rng,
446 chain,
447 self.subsample_frequency,
448 self.dynamic_step_size,
449 self.trajectory_kind,
450 switch_draw,
451 self.max_energy_error,
452 stats_options,
453 )
454 }
455
456 fn hint_num_tune(&self) -> usize {
457 usize_hint(self.num_tune, "num_tune")
458 }
459
460 fn hint_num_draws(&self) -> usize {
461 usize_hint(self.num_draws, "num_draws")
462 }
463
464 fn num_chains(&self) -> usize {
465 self.num_chains
466 }
467
468 fn seed(&self) -> u64 {
469 self.seed
470 }
471
472 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
473 StatOptions {
474 adapt: GlobalStrategyStatsOptions {
475 step_size: (),
476 mass_matrix: (),
477 },
478 hamiltonian: -1,
479 point: {
480 let store_gradient = self.store_gradient;
481 let store_unconstrained = self.store_unconstrained;
482 let store_transformed = self.store_transformed;
483 TransformedPointStatsOptions {
484 store_gradient,
485 store_unconstrained,
486 store_transformed,
487 }
488 },
489 divergence: crate::dynamics::DivergenceStatsOptions {
490 store_divergences: self.store_divergences,
491 },
492 }
493 }
494
495 fn sampler_name(&self) -> &'static str {
496 "mclmc"
497 }
498
499 fn adaptation_name(&self) -> &'static str {
500 "diagonal"
501 }
502}
503
504fn default_nuts_settings<A: Debug + Copy + Default + Serialize>(
505 adapt_options: A,
506 num_tune: u64,
507 num_chains: usize,
508 max_energy_error: f64,
509) -> NutsSettings<A> {
510 NutsSettings {
511 num_tune,
512 num_draws: 1000,
513 maxdepth: 10,
514 mindepth: 0,
515 max_energy_error,
516 store_gradient: false,
517 store_unconstrained: false,
518 store_transformed: false,
519 store_divergences: false,
520 adapt_options,
521 check_turning: true,
522 seed: 0,
523 num_chains,
524 target_integration_time: None,
525 trajectory_kind: KineticEnergyKind::Euclidean,
526 extra_doublings: 0,
527 }
528}
529
530impl Settings for LowRankMclmcSettings {
531 type Chain<M: Math> = LowRankMclmcChain<M>;
532
533 fn new_chain<M: Math, R: Rng + ?Sized>(
534 &self,
535 chain: u64,
536 mut math: M,
537 rng: &mut R,
538 ) -> Self::Chain<M> {
539 use crate::dynamics::KineticEnergyKind;
540 use crate::mclmc::MclmcChain;
541 use crate::stepsize::StepSizeAdaptMethod;
542
543 let num_tune = self.num_tune;
544 let mut adapt_options = self.adapt_options;
545 adapt_options.step_size_settings.adapt_options.method =
546 StepSizeAdaptMethod::Fixed(self.step_size);
547 let strategy = GlobalStrategy::<M, LowRankMassMatrixStrategy>::new(
548 &mut math,
549 adapt_options,
550 num_tune,
551 chain,
552 );
553 let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
554 let initial_kind = match self.trajectory_kind {
555 MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
556 MclmcTrajectoryKind::Euclidean
557 | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
558 };
559 let mut hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, initial_kind);
560 hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
561 let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
562 let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
563 let stats_options = self.stats_options::<M>();
564 MclmcChain::new(
565 math,
566 hamiltonian,
567 strategy,
568 rng,
569 chain,
570 self.subsample_frequency,
571 self.dynamic_step_size,
572 self.trajectory_kind,
573 switch_draw,
574 self.max_energy_error,
575 stats_options,
576 )
577 }
578
579 fn hint_num_tune(&self) -> usize {
580 usize_hint(self.num_tune, "num_tune")
581 }
582
583 fn hint_num_draws(&self) -> usize {
584 usize_hint(self.num_draws, "num_draws")
585 }
586
587 fn num_chains(&self) -> usize {
588 self.num_chains
589 }
590
591 fn seed(&self) -> u64 {
592 self.seed
593 }
594
595 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
596 StatOptions {
597 adapt: GlobalStrategyStatsOptions {
598 step_size: (),
599 mass_matrix: (),
600 },
601 hamiltonian: -1,
602 point: {
603 let store_gradient = self.store_gradient;
604 let store_unconstrained = self.store_unconstrained;
605 let store_transformed = self.store_transformed;
606 TransformedPointStatsOptions {
607 store_gradient,
608 store_unconstrained,
609 store_transformed,
610 }
611 },
612 divergence: crate::dynamics::DivergenceStatsOptions {
613 store_divergences: self.store_divergences,
614 },
615 }
616 }
617
618 fn sampler_name(&self) -> &'static str {
619 "mclmc"
620 }
621
622 fn adaptation_name(&self) -> &'static str {
623 "low_rank"
624 }
625}
626
627impl Default for DiagNutsSettings {
628 fn default() -> Self {
629 default_nuts_settings(EuclideanAdaptOptions::default(), 400, 6, 1000.0)
630 }
631}
632
633impl Default for LowRankNutsSettings {
634 fn default() -> Self {
635 let mut vals = default_nuts_settings(EuclideanAdaptOptions::default(), 800, 6, 1000.0);
636 vals.adapt_options.mass_matrix_update_freq = 20;
637 vals
638 }
639}
640
641impl Default for FlowNutsSettings {
642 fn default() -> Self {
643 default_nuts_settings(FlowSettings::default(), 1500, 1, 20.0)
644 }
645}
646
/// NUTS chain using the global adaptation strategy with diagonal mass-matrix adaptation.
type DiagNutsChain<M> = NutsChain<M, ChaCha8Rng, GlobalStrategy<M, DiagAdaptStrategy<M>>>;
/// NUTS chain using the global adaptation strategy with low-rank mass-matrix adaptation.
type LowRankNutsChain<M> = NutsChain<M, ChaCha8Rng, GlobalStrategy<M, LowRankMassMatrixStrategy>>;
649
650fn nuts_options(settings: &NutsSettings<impl Debug + Copy + Default + Serialize>) -> NutsOptions {
651 NutsOptions {
652 maxdepth: settings.maxdepth,
653 mindepth: settings.mindepth,
654 store_divergences: settings.store_divergences,
655 check_turning: settings.check_turning,
656 target_integration_time: settings.target_integration_time,
657 extra_doublings: settings.extra_doublings,
658 max_energy_error: settings.max_energy_error,
659 }
660}
661
662impl Settings for LowRankNutsSettings {
663 type Chain<M: Math> = LowRankNutsChain<M>;
664
665 fn new_chain<M: Math, R: Rng + ?Sized>(
666 &self,
667 chain: u64,
668 mut math: M,
669 mut rng: &mut R,
670 ) -> Self::Chain<M> {
671 let num_tune = self.num_tune;
672 let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
673 let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
674 let hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, self.trajectory_kind);
675
676 let options = nuts_options(self);
677
678 let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
679
680 NutsChain::new(
681 math,
682 hamiltonian,
683 strategy,
684 options,
685 rng,
686 chain,
687 self.stats_options(),
688 )
689 }
690
691 fn hint_num_tune(&self) -> usize {
692 usize_hint(self.num_tune, "num_tune")
693 }
694
695 fn hint_num_draws(&self) -> usize {
696 usize_hint(self.num_draws, "num_draws")
697 }
698
699 fn num_chains(&self) -> usize {
700 self.num_chains
701 }
702
703 fn seed(&self) -> u64 {
704 self.seed
705 }
706
707 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
708 StatOptions {
709 adapt: GlobalStrategyStatsOptions {
710 mass_matrix: (),
711 step_size: (),
712 },
713 hamiltonian: -1,
714 point: {
715 let store_gradient = self.store_gradient;
716 let store_unconstrained = self.store_unconstrained;
717 let store_transformed = self.store_transformed;
718 TransformedPointStatsOptions {
719 store_gradient,
720 store_unconstrained,
721 store_transformed,
722 }
723 },
724 divergence: crate::dynamics::DivergenceStatsOptions {
725 store_divergences: self.store_divergences,
726 },
727 }
728 }
729
730 fn sampler_name(&self) -> &'static str {
731 "nuts"
732 }
733
734 fn adaptation_name(&self) -> &'static str {
735 "low_rank"
736 }
737}
738
739impl Settings for DiagNutsSettings {
740 type Chain<M: Math> = DiagNutsChain<M>;
741
742 fn new_chain<M: Math, R: Rng + ?Sized>(
743 &self,
744 chain: u64,
745 mut math: M,
746 mut rng: &mut R,
747 ) -> Self::Chain<M> {
748 let num_tune = self.num_tune;
749 let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
750 let mass_matrix = DiagMassMatrix::new(
751 &mut math,
752 self.adapt_options.mass_matrix_options.store_mass_matrix,
753 );
754 let potential = TransformedHamiltonian::new(&mut math, mass_matrix, self.trajectory_kind);
755
756 let options = nuts_options(self);
757
758 let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
759
760 NutsChain::new(
761 math,
762 potential,
763 strategy,
764 options,
765 rng,
766 chain,
767 self.stats_options(),
768 )
769 }
770
771 fn hint_num_tune(&self) -> usize {
772 usize_hint(self.num_tune, "num_tune")
773 }
774
775 fn hint_num_draws(&self) -> usize {
776 usize_hint(self.num_draws, "num_draws")
777 }
778
779 fn num_chains(&self) -> usize {
780 self.num_chains
781 }
782
783 fn seed(&self) -> u64 {
784 self.seed
785 }
786
787 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
788 StatOptions {
789 adapt: GlobalStrategyStatsOptions {
790 mass_matrix: (),
791 step_size: (),
792 },
793 hamiltonian: -1,
794 point: {
795 let store_gradient = self.store_gradient;
796 let store_unconstrained = self.store_unconstrained;
797 let store_transformed = self.store_transformed;
798 TransformedPointStatsOptions {
799 store_gradient,
800 store_unconstrained,
801 store_transformed,
802 }
803 },
804 divergence: crate::dynamics::DivergenceStatsOptions {
805 store_divergences: self.store_divergences,
806 },
807 }
808 }
809
810 fn sampler_name(&self) -> &'static str {
811 "nuts"
812 }
813
814 fn adaptation_name(&self) -> &'static str {
815 "diagonal"
816 }
817}
818
819impl Settings for FlowNutsSettings {
820 type Chain<M: Math> = NutsChain<M, ChaCha8Rng, ExternalTransformAdaptation>;
821
822 fn new_chain<M: Math, R: Rng + ?Sized>(
823 &self,
824 chain: u64,
825 mut math: M,
826 mut rng: &mut R,
827 ) -> Self::Chain<M> {
828 let num_tune = self.num_tune;
829
830 let strategy =
831 ExternalTransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
832 let params = math
833 .new_transformation(rng, math.dim(), chain)
834 .expect("Failed to create external transformation");
835 let transform = ExternalTransformation::new(params);
836 let hamiltonian = TransformedHamiltonian::new(&mut math, transform, self.trajectory_kind);
837
838 let options = nuts_options(self);
839
840 let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
841 NutsChain::new(
842 math,
843 hamiltonian,
844 strategy,
845 options,
846 rng,
847 chain,
848 self.stats_options(),
849 )
850 }
851
852 fn hint_num_tune(&self) -> usize {
853 usize_hint(self.num_tune, "num_tune")
854 }
855
856 fn hint_num_draws(&self) -> usize {
857 usize_hint(self.num_draws, "num_draws")
858 }
859
860 fn num_chains(&self) -> usize {
861 self.num_chains
862 }
863
864 fn seed(&self) -> u64 {
865 self.seed
866 }
867
868 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
869 StatOptions {
870 adapt: (),
871 hamiltonian: (),
872 point: {
873 let store_gradient = self.store_gradient;
874 let store_unconstrained = self.store_unconstrained;
875 let store_transformed = self.store_transformed;
876 TransformedPointStatsOptions {
877 store_gradient,
878 store_unconstrained,
879 store_transformed,
880 }
881 },
882 divergence: crate::dynamics::DivergenceStatsOptions {
883 store_divergences: self.store_divergences,
884 },
885 }
886 }
887
888 fn sampler_name(&self) -> &'static str {
889 "nuts"
890 }
891
892 fn adaptation_name(&self) -> &'static str {
893 "flow"
894 }
895}
896
897impl Settings for FlowMclmcSettings {
898 type Chain<M: Math> = crate::mclmc::MclmcChain<
899 M,
900 ChaCha8Rng,
901 ExternalTransformAdaptation,
902 ExternalTransformation<M>,
903 >;
904
905 fn new_chain<M: Math, R: Rng + ?Sized>(
906 &self,
907 chain: u64,
908 mut math: M,
909 rng: &mut R,
910 ) -> Self::Chain<M> {
911 use crate::dynamics::KineticEnergyKind;
912 use crate::mclmc::MclmcChain;
913
914 let num_tune = self.num_tune;
915 let strategy =
916 ExternalTransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
917 let params = math
918 .new_transformation(rng, math.dim(), chain)
919 .expect("Failed to create external transformation");
920 let transform = ExternalTransformation::new(params);
921 let initial_kind = match self.trajectory_kind {
922 MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
923 MclmcTrajectoryKind::Euclidean
924 | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
925 };
926 let mut hamiltonian = TransformedHamiltonian::new(&mut math, transform, initial_kind);
927 hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
928 let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
929 let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
930 let stats_options = self.stats_options::<M>();
931 MclmcChain::new(
932 math,
933 hamiltonian,
934 strategy,
935 rng,
936 chain,
937 self.subsample_frequency,
938 self.dynamic_step_size,
939 self.trajectory_kind,
940 switch_draw,
941 self.max_energy_error,
942 stats_options,
943 )
944 }
945
946 fn hint_num_tune(&self) -> usize {
947 usize_hint(self.num_tune, "num_tune")
948 }
949
950 fn hint_num_draws(&self) -> usize {
951 usize_hint(self.num_draws, "num_draws")
952 }
953
954 fn num_chains(&self) -> usize {
955 self.num_chains
956 }
957
958 fn seed(&self) -> u64 {
959 self.seed
960 }
961
962 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
963 StatOptions {
964 adapt: (),
965 hamiltonian: (),
966 point: {
967 let store_gradient = self.store_gradient;
968 let store_unconstrained = self.store_unconstrained;
969 let store_transformed = self.store_transformed;
970 TransformedPointStatsOptions {
971 store_gradient,
972 store_unconstrained,
973 store_transformed,
974 }
975 },
976 divergence: crate::dynamics::DivergenceStatsOptions {
977 store_divergences: self.store_divergences,
978 },
979 }
980 }
981
982 fn sampler_name(&self) -> &'static str {
983 "mclmc"
984 }
985
986 fn adaptation_name(&self) -> &'static str {
987 "flow"
988 }
989}
990
991pub fn sample_sequentially<'math, M: Math + 'math, R: Rng + ?Sized>(
992 math: M,
993 settings: DiagNutsSettings,
994 start: &[f64],
995 draws: u64,
996 chain: u64,
997 rng: &mut R,
998) -> Result<impl Iterator<Item = Result<(Box<[f64]>, Progress)>> + 'math> {
999 let mut sampler = settings.new_chain(chain, math, rng);
1000 sampler.set_position(start)?;
1001 Ok((0..draws).map(move |_| sampler.draw()))
1002}
1003
/// Running progress information for one chain, maintained by `ChainProgress::update`.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct ChainProgress {
    /// Number of draws completed so far, tuning draws included.
    pub finished_draws: usize,
    /// Total number of draws the chain is expected to perform (tuning plus sampling).
    pub total_draws: usize,
    /// Number of divergent draws outside of tuning.
    pub divergences: usize,
    /// Whether the most recent draw was a tuning draw (`true` before the first draw).
    pub tuning: bool,
    /// Set once the chain's sampler has been created.
    pub started: bool,
    /// Number of integration steps of the most recent draw.
    pub latest_num_steps: usize,
    /// Number of integration steps summed over all draws.
    pub total_num_steps: usize,
    /// Step size of the most recent draw.
    pub step_size: f64,
    /// Sum of the `draw_duration` values passed to `update`.
    pub runtime: Duration,
    /// `finished_draws` indices of the non-tuning divergent draws.
    pub divergent_draws: Vec<usize>,
}
1018
1019impl ChainProgress {
1020 fn new(total: usize) -> Self {
1021 Self {
1022 finished_draws: 0,
1023 total_draws: total,
1024 divergences: 0,
1025 tuning: true,
1026 started: false,
1027 latest_num_steps: 0,
1028 step_size: 0f64,
1029 total_num_steps: 0,
1030 runtime: Duration::ZERO,
1031 divergent_draws: Vec::new(),
1032 }
1033 }
1034
1035 fn update(&mut self, stats: &Progress, draw_duration: Duration) {
1036 if stats.diverging & !stats.tuning {
1037 self.divergences += 1;
1038 self.divergent_draws.push(self.finished_draws);
1039 }
1040 self.finished_draws += 1;
1041 self.tuning = stats.tuning;
1042
1043 self.latest_num_steps = stats.num_steps as usize;
1044 self.total_num_steps += stats.num_steps as usize;
1045 self.step_size = stats.step_size;
1046 self.runtime += draw_duration;
1047 }
1048}
1049
/// Control message sent from a `ChainProcess` handle to its worker task.
#[cfg(feature = "parallel")]
enum ChainCommand {
    /// Continue sampling; also releases a worker that is blocked after `Pause`.
    Resume,
    /// Make the worker block until the next command (or a disconnect) arrives.
    Pause,
}
1055
/// Handle to one chain whose sampling loop runs as a task in the rayon scope.
#[cfg(feature = "parallel")]
struct ChainProcess<T>
where
    T: TraceStorage,
{
    /// Sends pause/resume commands to the worker. When the sender is dropped, the worker sees
    /// a disconnect and stops sampling.
    stop_marker: Sender<ChainCommand>,
    /// Chain storage shared with the worker. It becomes `None` once taken for finalization,
    /// which also makes the worker stop.
    trace: Arc<Mutex<Option<T::ChainStorage>>>,
    /// Progress record shared with (and updated by) the worker.
    progress: Arc<Mutex<ChainProgress>>,
}
1065
#[cfg(feature = "parallel")]
impl<T: TraceStorage> ChainProcess<T> {
    /// Finalize the storage of every chain in `chains`, then finalize `trace` with them.
    ///
    /// Chains whose storage has already been taken out of the shared slot are skipped.
    fn finalize_many(trace: T, chains: Vec<Self>) -> Result<(Option<anyhow::Error>, T::Finalized)> {
        let finalized_chain_traces = chains
            .into_iter()
            .filter_map(|chain| chain.trace.lock().expect("Poisoned lock").take())
            .map(|chain| chain.finalize())
            .collect_vec();
        trace.finalize(finalized_chain_traces)
    }

    /// Snapshot of the chain's current progress.
    fn progress(&self) -> ChainProgress {
        self.progress.lock().expect("Poisoned lock").clone()
    }

    /// Ask the worker to continue sampling.
    ///
    /// # Errors
    /// Fails if the worker has already stopped and dropped its receiver.
    fn resume(&self) -> Result<()> {
        self.stop_marker.send(ChainCommand::Resume)?;
        Ok(())
    }

    /// Ask the worker to pause until the next command.
    ///
    /// # Errors
    /// Fails if the worker has already stopped and dropped its receiver.
    fn pause(&self) -> Result<()> {
        self.stop_marker.send(ChainCommand::Pause)?;
        Ok(())
    }

    /// Spawn the sampling loop of chain `chain_id` onto `scope` and return a handle to it.
    ///
    /// The worker builds the model density and the chain, searches for a valid initial
    /// position, and then performs `hint_num_tune() + hint_num_draws()` draws. Each draw is
    /// recorded into `chain_trace`. The worker stops early when the handle's command channel
    /// disconnects or when the trace has been taken for finalization. The worker's final
    /// `Result` (including any error while drawing) is sent through `results`.
    fn start<'model, M: Model, S: Settings>(
        model: &'model M,
        chain_trace: T::ChainStorage,
        chain_id: u64,
        seed: u64,
        settings: &'model S,
        scope: &ScopeFifo<'model>,
        results: Sender<Result<()>>,
    ) -> Result<Self> {
        // Number of random initial positions tried before the chain gives up.
        const MAX_INIT_ATTEMPTS: usize = 500;

        let (stop_marker_tx, stop_marker_rx) = channel();

        // Every chain shares the seed but uses its own stream. Stream 0 is used by
        // `Sampler::new`, so chains start at stream 1.
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        rng.set_stream(chain_id + 1);

        let chain_trace = Arc::new(Mutex::new(Some(chain_trace)));
        let progress = Arc::new(Mutex::new(ChainProgress::new(
            settings.hint_num_draws() + settings.hint_num_tune(),
        )));

        let trace_inner = chain_trace.clone();
        let progress_inner = progress.clone();

        scope.spawn_fifo(move |_| {
            let chain_trace = trace_inner;
            let progress = progress_inner;

            let mut sample = move || {
                let logp = model
                    .math(&mut rng)
                    .context("Failed to create model density")?;
                let dim = logp.dim();

                let mut sampler = settings.new_chain(chain_id, logp, &mut rng);

                progress.lock().expect("Poisoned mutex").started = true;

                let mut initval = vec![0f64; dim];
                let mut error = None;
                for _ in 0..MAX_INIT_ATTEMPTS {
                    model
                        .init_position(&mut rng, &mut initval)
                        .context("Failed to generate a new initial position")?;
                    if let Err(err) = sampler.set_position(&initval) {
                        error = Some(err);
                        continue;
                    }
                    error = None;
                    break;
                }

                if let Some(error) = error {
                    return Err(error.context("All initialization points failed"));
                }

                let draws = settings.hint_num_tune() + settings.hint_num_draws();

                let mut msg = stop_marker_rx.try_recv();
                let mut draw = 0;
                // A `while` condition rather than a post-increment `draw == draws` check, so that
                // a configuration with zero total draws terminates instead of looping forever.
                while draw < draws {
                    match msg {
                        // The controlling handle was dropped: stop sampling.
                        Err(TryRecvError::Disconnected) => {
                            break;
                        }
                        Err(TryRecvError::Empty) => {}
                        Ok(ChainCommand::Pause) => {
                            // Block until the next command; a disconnect maps to `Disconnected`.
                            msg = stop_marker_rx.recv().map_err(|e| e.into());
                            continue;
                        }
                        Ok(ChainCommand::Resume) => {}
                    }

                    let now = Instant::now();
                    // Propagate sampler failures through the worker's `Result` instead of
                    // panicking inside the thread-pool scope.
                    let (_point, mut draw_data, mut stats, info) = sampler
                        .expanded_draw()
                        .context("Failed to draw a sample")?;
                    // Measure only the draw, not the time spent waiting for the locks below.
                    let draw_duration = now.elapsed();

                    let mut guard = chain_trace
                        .lock()
                        .expect("Could not unlock trace lock. Poisoned mutex");

                    // The trace slot is emptied when the chain is finalized: stop sampling.
                    let Some(trace_val) = guard.as_mut() else {
                        break;
                    };
                    progress
                        .lock()
                        .expect("Poisoned mutex")
                        .update(&info, draw_duration);

                    let math = sampler.math();
                    let dims = StatsDims::from(math.deref());
                    trace_val.record_sample(
                        settings,
                        stats.get_all(&dims),
                        draw_data.get_all(math.deref()),
                        &info,
                    )?;

                    draw += 1;
                    msg = stop_marker_rx.try_recv();
                }
                Ok(())
            };

            let result = sample();

            // The receiving side may already be gone during shutdown; nothing to report then.
            let _ = results.send(result);
            drop(results);
        });

        Ok(Self {
            trace: chain_trace,
            stop_marker: stop_marker_tx,
            progress,
        })
    }

    /// Flush the chain's storage, if it has not been taken for finalization yet.
    ///
    /// # Errors
    /// Fails if the trace mutex is poisoned or the storage's own flush fails.
    fn flush(&self) -> Result<()> {
        self.trace
            .lock()
            .map_err(|_| anyhow::anyhow!("Could not lock trace mutex"))
            .context("Could not flush trace")?
            .as_mut()
            .map(|v| v.flush())
            .transpose()?;
        Ok(())
    }
}
1225
/// Command sent from a `Sampler` handle to its controller thread.
///
/// NOTE(review): the handler is outside this chunk. The variant notes below are inferred from
/// the names and the matching `ChainProcess` methods; confirm them.
#[cfg(feature = "parallel")]
#[derive(Debug)]
enum SamplerCommand {
    /// Presumably pauses all chains (`ChainProcess::pause`).
    Pause,
    /// Presumably resumes all chains (`ChainProcess::resume`).
    Continue,
    /// Presumably requests a `SamplerResponse::Progress` snapshot.
    Progress,
    /// Presumably flushes every chain's storage (`ChainProcess::flush`).
    Flush,
    /// Presumably requests a `SamplerResponse::Inspect` payload.
    Inspect,
}
1235
/// Reply from the controller thread to a `SamplerCommand`.
#[cfg(feature = "parallel")]
enum SamplerResponse<T: Send + 'static> {
    /// Acknowledgement without payload.
    Ok(),
    /// Progress of every chain.
    Progress(Box<[ChainProgress]>),
    /// Inspection payload; `Sampler` uses `T = (Option<anyhow::Error>, F)`.
    Inspect(T),
}
1242
/// Outcome of waiting for a `Sampler` to finish.
///
/// NOTE(review): the method producing this value is outside this chunk; the variant notes are
/// inferred from the names and payloads and should be confirmed.
#[cfg(feature = "parallel")]
pub enum SamplerWaitResult<F: Send + 'static> {
    /// Presumably: sampling finished, carrying the finalized trace.
    Trace(F),
    /// Presumably: the wait timed out and the still-running sampler is handed back.
    Timeout(Sampler<F>),
    /// Presumably: sampling failed, carrying the error and the trace if one could be produced.
    Err(anyhow::Error, Option<F>),
}
1249
/// Handle to a multi-chain sampling run driven by a dedicated controller thread.
#[cfg(feature = "parallel")]
pub struct Sampler<F: Send + 'static> {
    /// Controller thread spawned by `Sampler::new`; yields the finalized trace on success.
    main_thread: JoinHandle<Result<(Option<anyhow::Error>, F)>>,
    /// Commands to the controller thread (a zero-capacity `sync_channel`, i.e. rendezvous).
    commands: SyncSender<SamplerCommand>,
    /// Replies from the controller thread (also a zero-capacity `sync_channel`).
    responses: Receiver<SamplerResponse<(Option<anyhow::Error>, F)>>,
    /// Final result of each chain worker, as sent by `ChainProcess::start`.
    results: Receiver<Result<()>>,
}
1257
/// Periodic progress reporting for a `Sampler`.
#[cfg(feature = "parallel")]
pub struct ProgressCallback {
    /// Invoked from the sampler's controller loop with the elapsed sampling time (time
    /// spent paused is subtracted) and the progress of every chain. It is called once
    /// when sampling starts, whenever `rate` has passed since the previous call, and once
    /// more when the controller shuts down.
    pub callback: Box<dyn FnMut(Duration, Box<[ChainProgress]>) + Send>,
    /// Interval between two periodic invocations of `callback`.
    pub rate: Duration,
}
1263
#[cfg(feature = "parallel")]
impl<F: Send + 'static> Sampler<F> {
    /// Start sampling `model` with `settings` on a background controller thread.
    ///
    /// The controller builds a rayon pool with `num_cores + 1` threads, creates the trace
    /// from `trace_config`, starts `settings.num_chains()` chains and then serves commands
    /// (pause, resume, progress, flush, inspect) from this handle until the handle is
    /// dropped or [`Sampler::abort`] is called. If `callback` is given, it is invoked at
    /// start-up, every `callback.rate`, and at shutdown. Each call receives the sampling
    /// time (excluding time spent paused) and the progress of every chain.
    ///
    /// # Errors
    ///
    /// This call itself always returns `Ok`. Setup failures (thread pool, model density,
    /// trace or chain creation) are reported later through [`Sampler::abort`] or
    /// [`Sampler::wait_timeout`].
    pub fn new<M, S, C, T>(
        model: M,
        settings: S,
        trace_config: C,
        num_cores: usize,
        callback: Option<ProgressCallback>,
    ) -> Result<Self>
    where
        S: Settings,
        C: StorageConfig<Storage = T>,
        M: Model,
        T: TraceStorage<Finalized = F>,
    {
        // Wall-clock time since `start_time`, minus the completed pauses accumulated in
        // `pause_time` and minus the pause still in progress (if `paused_since` is set).
        fn active_elapsed(
            start_time: Instant,
            pause_time: Duration,
            paused_since: Option<Instant>,
        ) -> Duration {
            let current_pause = paused_since.map_or(Duration::ZERO, |since| since.elapsed());
            start_time
                .elapsed()
                .saturating_sub(pause_time)
                .saturating_sub(current_pause)
        }

        // Rendezvous channels: each command/response is handed over directly.
        let (commands_tx, commands_rx) = sync_channel(0);
        let (responses_tx, responses_rx) = sync_channel(0);
        // Every chain sends its final sampling outcome through a clone of `results_tx`.
        let (results_tx, results_rx) = channel();

        let main_thread = spawn(move || {
            // The controller loop below runs inside `pool.scope_fifo`, i.e. on one of the
            // pool's own threads, so reserve one thread beyond `num_cores` for it.
            let pool = ThreadPoolBuilder::new()
                .num_threads(num_cores + 1)
                .thread_name(|i| format!("nutpie-worker-{i}"))
                .build()
                .context("Could not start thread pool")?;

            let settings_ref = &settings;
            let model_ref = &model;
            let mut callback = callback;

            pool.scope_fifo(move |scope| {
                let results = results_tx;
                let mut chains = Vec::with_capacity(settings.num_chains());

                let mut rng = ChaCha8Rng::seed_from_u64(settings.seed());
                rng.set_stream(0);

                let math = model_ref
                    .math(&mut rng)
                    .context("Could not create model density")?;
                let trace = trace_config
                    .new_trace(settings_ref, &math)
                    .context("Could not create trace object")?;
                drop(math);

                for chain_id in 0..settings.num_chains() {
                    let chain_trace_val = trace
                        .initialize_trace_for_chain(chain_id as u64)
                        .context("Failed to create trace object")?;
                    let chain = ChainProcess::start(
                        model_ref,
                        chain_trace_val,
                        chain_id as u64,
                        settings.seed(),
                        settings_ref,
                        scope,
                        results.clone(),
                    );
                    chains.push(chain);
                }
                // Only the chains hold result senders from here on, so the receiver
                // disconnects once all of them are done.
                drop(results);

                let (chains, errors): (Vec<_>, Vec<_>) = chains.into_iter().partition_result();
                if let Some(error) = errors.into_iter().next() {
                    let _ = ChainProcess::finalize_many(trace, chains);
                    return Err(error).context("Could not start chains");
                }

                let mut main_loop = || {
                    let snapshot = || -> Box<[ChainProgress]> {
                        chains.iter().map(|chain| chain.progress()).collect()
                    };

                    let start_time = Instant::now();
                    // Total duration of all completed pauses.
                    let mut pause_time = Duration::ZERO;
                    // Start of the pause currently in progress, if any.
                    let mut paused_since: Option<Instant> = None;

                    // Without a callback the timeout is effectively infinite, so the loop
                    // simply blocks until the next command arrives.
                    let mut progress_rate = Duration::MAX;
                    if let Some(ProgressCallback { callback, rate }) = &mut callback {
                        callback(start_time.elapsed(), snapshot());
                        progress_rate = *rate;
                    }
                    let mut last_progress = Instant::now();

                    loop {
                        // Report progress if it is due, otherwise wait for a command until
                        // the next report becomes due.
                        let timeout = progress_rate
                            .checked_sub(last_progress.elapsed())
                            .unwrap_or_else(|| {
                                if let Some(ProgressCallback { callback, .. }) = &mut callback {
                                    let elapsed =
                                        active_elapsed(start_time, pause_time, paused_since);
                                    callback(elapsed, snapshot());
                                }
                                last_progress = Instant::now();
                                progress_rate
                            });

                        // Note: the response payload contains `F`, which need not be
                        // `Sync`, so send errors are formatted instead of wrapped with
                        // `.context`.
                        match commands_rx.recv_timeout(timeout) {
                            Ok(SamplerCommand::Pause) => {
                                // Errors from individual chains are ignored here.
                                for chain in chains.iter() {
                                    let _ = chain.pause();
                                }
                                // Pausing an already paused sampler keeps the original
                                // pause start.
                                if paused_since.is_none() {
                                    paused_since = Some(Instant::now());
                                }
                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
                                    anyhow::anyhow!(
                                        "Could not send pause response to controller thread: {e}"
                                    )
                                })?;
                            }
                            Ok(SamplerCommand::Continue) => {
                                for chain in chains.iter() {
                                    let _ = chain.resume();
                                }
                                // Only a pause that is actually in progress counts towards
                                // the paused time; resuming a running sampler adds nothing.
                                if let Some(since) = paused_since.take() {
                                    pause_time += since.elapsed();
                                }
                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
                                    anyhow::anyhow!(
                                        "Could not send continue response to controller thread: {e}"
                                    )
                                })?;
                            }
                            Ok(SamplerCommand::Progress) => {
                                responses_tx
                                    .send(SamplerResponse::Progress(snapshot()))
                                    .map_err(|e| {
                                        anyhow::anyhow!(
                                            "Could not send progress response to controller thread: {e}"
                                        )
                                    })?;
                            }
                            Ok(SamplerCommand::Inspect) => {
                                let traces = chains
                                    .iter()
                                    .filter_map(|chain| {
                                        chain
                                            .trace
                                            .lock()
                                            .expect("Poisoned lock")
                                            .as_ref()
                                            .map(|v| v.inspect())
                                    })
                                    .collect_vec();
                                let finalized_trace = trace.inspect(traces)?;
                                responses_tx
                                    .send(SamplerResponse::Inspect(finalized_trace))
                                    .map_err(|e| {
                                        anyhow::anyhow!(
                                            "Could not send inspect response to controller thread: {e}"
                                        )
                                    })?;
                            }
                            Ok(SamplerCommand::Flush) => {
                                for chain in chains.iter() {
                                    chain.flush()?;
                                }
                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
                                    anyhow::anyhow!(
                                        "Could not send flush response to controller thread: {e}"
                                    )
                                })?;
                            }
                            Err(RecvTimeoutError::Timeout) => {}
                            Err(RecvTimeoutError::Disconnected) => {
                                // The `Sampler` handle was dropped or aborted: report the
                                // final progress once and shut down.
                                if let Some(ProgressCallback { callback, .. }) = &mut callback {
                                    let elapsed =
                                        active_elapsed(start_time, pause_time, paused_since);
                                    callback(elapsed, snapshot());
                                }
                                return Ok(());
                            }
                        };
                    }
                };
                let result: Result<()> = main_loop();
                let output = ChainProcess::finalize_many(trace, chains)?;

                result?;
                Ok(output)
            })
        });

        Ok(Self {
            main_thread,
            commands: commands_tx,
            responses: responses_rx,
            results: results_rx,
        })
    }

    /// Send `command` to the controller thread and block until it replies.
    ///
    /// `what` names the command in error messages.
    fn request(
        &self,
        command: SamplerCommand,
        what: &str,
    ) -> Result<SamplerResponse<(Option<anyhow::Error>, F)>> {
        self.commands
            .send(command)
            .with_context(|| format!("Could not send {what} command to controller thread"))?;
        self.responses
            .recv()
            .with_context(|| format!("Could not receive {what} response from controller thread"))
    }

    /// Send `command` and require a plain acknowledgement as the reply.
    fn request_ack(&self, command: SamplerCommand, what: &str) -> Result<()> {
        let SamplerResponse::Ok() = self.request(command, what)? else {
            bail!("Got invalid response from sample controller thread");
        };
        Ok(())
    }

    /// Pause all chains and wait until the controller has acknowledged it.
    ///
    /// Pausing an already paused sampler is allowed.
    pub fn pause(&mut self) -> Result<()> {
        self.request_ack(SamplerCommand::Pause, "pause")
    }

    /// Resume all chains. Time spent paused is excluded from the elapsed time reported to
    /// the progress callback.
    pub fn resume(&mut self) -> Result<()> {
        self.request_ack(SamplerCommand::Continue, "continue")
    }

    /// Flush the trace of every chain and wait for completion.
    pub fn flush(&mut self) -> Result<()> {
        self.request_ack(SamplerCommand::Flush, "flush")
    }

    /// Return a snapshot of the trace collected so far without stopping sampling.
    ///
    /// The pair is whatever the storage's `inspect` produced (presumably an optional
    /// storage error plus the partial trace).
    pub fn inspect(&mut self) -> Result<(Option<anyhow::Error>, F)> {
        let SamplerResponse::Inspect(trace) = self.request(SamplerCommand::Inspect, "inspect")?
        else {
            bail!("Got invalid response from sample controller thread");
        };
        Ok(trace)
    }

    /// Stop sampling and wait for the controller thread to finalize the trace.
    ///
    /// Dropping the command channel makes the controller leave its loop and finalize all
    /// chains. Returns the finalized trace together with the optional error reported by
    /// finalization.
    ///
    /// # Panics
    ///
    /// Re-raises a panic that occurred on the controller thread.
    pub fn abort(self) -> Result<(Option<anyhow::Error>, F)> {
        drop(self.commands);
        self.main_thread
            .join()
            .unwrap_or_else(|payload| std::panic::resume_unwind(payload))
    }

    /// Wait at most `timeout` in total for sampling to finish.
    ///
    /// Returns [`SamplerWaitResult::Trace`] once every chain has reported back and the
    /// trace was finalized, [`SamplerWaitResult::Err`] if a chain or the controller failed,
    /// and [`SamplerWaitResult::Timeout`] (handing the sampler back) if the time ran out.
    pub fn wait_timeout(self, timeout: Duration) -> SamplerWaitResult<F> {
        let start = Instant::now();
        let mut remaining = timeout;
        loop {
            match self.results.recv_timeout(remaining) {
                // One chain finished; keep waiting for the others, but only for what is
                // left of the caller's overall budget.
                Ok(Ok(())) => match timeout.checked_sub(start.elapsed()) {
                    Some(left) => remaining = left,
                    None => break,
                },
                // NOTE(review): the trace is discarded here; dropping `self` disconnects the
                // command channel, so the controller finalizes the chains in the background.
                Ok(Err(e)) => return SamplerWaitResult::Err(e, None),
                // All result senders are gone (every chain task has finished, or the
                // controller exited early): collect the outcome from the controller.
                Err(RecvTimeoutError::Disconnected) => {
                    return match self.abort() {
                        Ok((Some(err), trace)) => SamplerWaitResult::Err(err, Some(trace)),
                        Ok((None, trace)) => SamplerWaitResult::Trace(trace),
                        Err(err) => SamplerWaitResult::Err(err, None),
                    };
                }
                Err(RecvTimeoutError::Timeout) => break,
            }
        }
        SamplerWaitResult::Timeout(self)
    }

    /// Return the current progress of every chain.
    pub fn progress(&mut self) -> Result<Box<[ChainProgress]>> {
        let SamplerResponse::Progress(progress) =
            self.request(SamplerCommand::Progress, "progress")?
        else {
            bail!("Got invalid response from sample controller thread");
        };
        Ok(progress)
    }
}
1550
#[cfg(test)]
pub mod test_logps {
    #[cfg(feature = "zarr")]
    use crate::{Model, math::CpuLogpFunc, math::CpuMath};
    #[cfg(feature = "zarr")]
    use anyhow::Result;
    #[cfg(feature = "zarr")]
    use rand::Rng;

    /// Test model wrapping a CPU log-density function `logp`.
    #[cfg(feature = "zarr")]
    pub struct CpuModel<F> {
        logp: F,
    }

    #[cfg(feature = "zarr")]
    impl<F> CpuModel<F> {
        /// Wrap `logp` in a model.
        pub fn new(logp: F) -> Self {
            CpuModel { logp }
        }
    }

    #[cfg(feature = "zarr")]
    impl<F> Model for CpuModel<F>
    where
        F: Send + Sync + 'static,
        for<'a> &'a F: CpuLogpFunc,
    {
        type Math<'model> = CpuMath<&'model F>;

        /// Borrow the wrapped density; the RNG is not used.
        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            let logp = &self.logp;
            Ok(CpuMath::new(logp))
        }

        /// Start every coordinate at zero; the RNG is not used.
        fn init_position<R: Rng + ?Sized>(
            &self,
            _rng: &mut R,
            position: &mut [f64],
        ) -> Result<()> {
            position.fill(0.0);
            Ok(())
        }
    }
}
1594
#[cfg(test)]
mod tests {
    use crate::math::test_logps::NormalLogp;
    use crate::{
        Chain, math::CpuMath, sample_sequentially, sampler::DiagMclmcSettings,
        sampler::DiagNutsSettings, sampler::LowRankMclmcSettings, sampler::LowRankNutsSettings,
        sampler::Settings,
    };

    #[cfg(feature = "zarr")]
    use super::test_logps::CpuModel;

    use anyhow::Result;
    use itertools::Itertools;
    use pretty_assertions::assert_eq;
    use rand::{SeedableRng, rngs::StdRng};

    #[cfg(feature = "zarr")]
    use std::{
        sync::Arc,
        time::{Duration, Instant},
    };

    #[cfg(feature = "zarr")]
    use crate::{Sampler, ZarrConfig};

    #[cfg(feature = "zarr")]
    use zarrs::storage::store::MemoryStore;

    /// Smoke-test a settings type: its stat metadata must be non-empty and consistent, and
    /// a freshly created chain must be able to take one draw.
    fn assert_settings_smoke<S: Settings>(settings: S) -> Result<()> {
        let target = NormalLogp { dim: 4, mu: 0.1 };
        let math = CpuMath::new(&target);
        let mut rng = StdRng::seed_from_u64(42);

        let names = settings.stat_names(&math);
        let types = settings.stat_types(&math);
        assert!(!names.is_empty());
        assert_eq!(names.len(), types.len());

        let mut chain = settings.new_chain(0, math, &mut rng);
        chain.set_position(&vec![0.2; 4])?;
        let (_draw, _info) = chain.draw()?;
        Ok(())
    }

    /// Draw 200 samples sequentially as chain 1 on a 10-dimensional target and check the
    /// shape and metadata of the draw at index 100.
    fn check_sequential_run(logp: &NormalLogp, settings: DiagNutsSettings, rng: &mut StdRng) {
        let start = vec![0.2; 10];
        let math = CpuMath::new(logp);
        let draws = sample_sequentially(math, settings, &start, 200, 1, rng)
            .unwrap()
            .collect_vec();
        assert_eq!(draws.len(), 200);

        let (vals, stats) = draws.into_iter().nth(100).unwrap().unwrap();
        assert_eq!(vals.len(), 10);
        assert_eq!(stats.chain, 1);
        assert_eq!(stats.draw, 100);
    }

    #[test]
    fn all_settings_smoke() -> Result<()> {
        assert_settings_smoke(DiagNutsSettings {
            num_tune: 10,
            num_draws: 10,
            ..Default::default()
        })?;
        assert_settings_smoke(LowRankNutsSettings {
            num_tune: 10,
            num_draws: 10,
            ..Default::default()
        })?;
        assert_settings_smoke(DiagMclmcSettings {
            num_tune: 10,
            num_draws: 10,
            ..Default::default()
        })?;
        assert_settings_smoke(LowRankMclmcSettings {
            num_tune: 10,
            num_draws: 10,
            ..Default::default()
        })
    }

    #[test]
    fn sample_chain() -> Result<()> {
        let logp = NormalLogp { dim: 10, mu: 0.1 };
        let settings = DiagNutsSettings {
            num_tune: 100,
            num_draws: 100,
            ..Default::default()
        };
        let mut rng = StdRng::seed_from_u64(42);

        // The first draw of a fresh chain is a tuning draw with index 0.
        let mut chain = settings.new_chain(0, CpuMath::new(&logp), &mut rng);
        let (_draw, info) = chain.draw()?;
        assert!(info.tuning);
        assert_eq!(info.draw, 0);

        check_sequential_run(&logp, settings, &mut rng);
        Ok(())
    }

    #[cfg(feature = "zarr")]
    #[test]
    fn sample_parallel() -> Result<()> {
        let logp = NormalLogp { dim: 100, mu: 0.1 };
        let settings = DiagNutsSettings {
            num_tune: 100,
            num_draws: 100,
            seed: 10,
            ..Default::default()
        };

        // Every scenario gets its own model copy and in-memory store.
        let launch = || {
            let model = CpuModel::new(logp.clone());
            let zarr_config = ZarrConfig::new(Arc::new(MemoryStore::new()));
            Sampler::new(model, settings, zarr_config, 4, None)
        };

        // Repeated pauses followed by a resume must be accepted before aborting.
        let mut sampler = launch()?;
        sampler.pause()?;
        sampler.pause()?;
        sampler.resume()?;
        let (sampling_error, _) = sampler.abort()?;
        if let Some(err) = sampling_error {
            return Err(err);
        }

        // Aborting while paused must still finalize cleanly.
        let mut sampler = launch()?;
        sampler.pause()?;
        if let (Some(err), _) = sampler.abort()? {
            return Err(err);
        }

        // A tiny timeout must hand the sampler back; a generous one must finish.
        let start = Instant::now();
        let sampler = launch()?;

        let mut sampler = match sampler.wait_timeout(Duration::from_nanos(100)) {
            super::SamplerWaitResult::Timeout(sampler) => sampler,
            super::SamplerWaitResult::Trace(_) => {
                dbg!(start.elapsed());
                panic!("finished");
            }
            super::SamplerWaitResult::Err(_, _) => panic!("error"),
        };

        for _ in 0..30 {
            sampler.progress()?;
        }

        match sampler.wait_timeout(Duration::from_secs(1)) {
            super::SamplerWaitResult::Trace(_) => {
                dbg!(start.elapsed());
            }
            super::SamplerWaitResult::Timeout(_) => panic!("timeout"),
            super::SamplerWaitResult::Err(err, _) => return Err(err),
        };

        Ok(())
    }

    #[test]
    fn sample_seq() {
        let logp = NormalLogp { dim: 10, mu: 0.1 };
        let settings = DiagNutsSettings {
            num_tune: 100,
            num_draws: 100,
            ..Default::default()
        };
        let mut rng = StdRng::seed_from_u64(42);

        check_sequential_run(&logp, settings, &mut rng);
    }
}