1use anyhow::Result;
5use nuts_storable::{HasDims, Storable, Value};
6use rand::{Rng, SeedableRng, rngs::ChaCha8Rng};
7use serde::{Deserialize, Serialize, de::DeserializeOwned};
8use std::{collections::HashMap, fmt::Debug, time::Duration};
9
10#[cfg(feature = "parallel")]
11use anyhow::{Context, bail};
12#[cfg(feature = "parallel")]
13use itertools::Itertools;
14#[cfg(feature = "parallel")]
15use std::ops::Deref;
16
17#[cfg(feature = "parallel")]
18use rayon::{ScopeFifo, ThreadPoolBuilder};
19#[cfg(feature = "parallel")]
20use std::{
21 sync::{
22 Arc, Mutex,
23 mpsc::{
24 Receiver, RecvTimeoutError, Sender, SyncSender, TryRecvError, channel, sync_channel,
25 },
26 },
27 thread::{JoinHandle, spawn},
28 time::Instant,
29};
30
31use crate::{
32 DiagAdaptExpSettings, Math, StepSizeAdaptMethod,
33 adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy, GlobalStrategyStatsOptions},
34 chain::{AdaptStrategy, Chain, NutsChain, StatOptions},
35 dynamics::{KineticEnergyKind, TransformedHamiltonian, TransformedPointStatsOptions},
36 external_adapt_strategy::{ExternalTransformAdaptation, FlowSettings},
37 mclmc::MclmcTrajectoryKind,
38 nuts::NutsOptions,
39 sampler_stats::{SamplerStats, StatsDims},
40 transform::{
41 DiagAdaptStrategy, DiagMassMatrix, ExternalTransformation, LowRankMassMatrix,
42 LowRankMassMatrixStrategy, LowRankSettings,
43 },
44};
45
46#[cfg(feature = "parallel")]
47use crate::{
48 model::Model,
49 storage::{ChainStorage, StorageConfig, TraceStorage},
50};
51
52pub trait Settings:
54 private::Sealed + Clone + Copy + Default + Sync + Send + Serialize + DeserializeOwned + 'static
55{
56 type Chain<M: Math>: Chain<M>;
57
58 fn new_chain<M: Math, R: Rng + ?Sized>(
59 &self,
60 chain: u64,
61 math: M,
62 rng: &mut R,
63 ) -> Self::Chain<M>;
64
65 fn hint_num_tune(&self) -> usize;
66 fn hint_num_draws(&self) -> usize;
67 fn num_chains(&self) -> usize;
68 fn seed(&self) -> u64;
69 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions;
70 fn sampler_name(&self) -> &'static str;
71 fn adaptation_name(&self) -> &'static str;
72
73 fn stat_names<M: Math>(&self, math: &M) -> Vec<String> {
74 let dims = StatsDims::from(math);
75 <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::names(&dims)
76 .into_iter()
77 .map(String::from)
78 .collect()
79 }
80
81 fn data_names<M: Math>(&self, math: &M) -> Vec<String> {
82 <M::ExpandedVector as Storable<_>>::names(math)
83 .into_iter()
84 .map(String::from)
85 .collect()
86 }
87
88 fn stat_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
89 self.stat_names(math)
90 .into_iter()
91 .map(|name| (name.clone(), self.stat_type::<M>(math, &name)))
92 .collect()
93 }
94
95 fn stat_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
96 let dims = StatsDims::from(math);
97 <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::item_type(&dims, name)
98 }
99
100 fn data_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
101 self.data_names(math)
102 .into_iter()
103 .map(|name| (name.clone(), self.data_type(math, &name)))
104 .collect()
105 }
106 fn data_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
107 <M::ExpandedVector as Storable<_>>::item_type(math, name)
108 }
109
110 fn stat_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
111 self.stat_names(math)
112 .into_iter()
113 .map(|name| (name.clone(), self.stat_dims::<M>(math, &name)))
114 .collect()
115 }
116
117 fn stat_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
118 let dims = StatsDims::from(math);
119 <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::dims(&dims, name)
120 .into_iter()
121 .map(String::from)
122 .collect()
123 }
124
125 fn stat_dim_sizes<M: Math>(&self, math: &M) -> HashMap<String, u64> {
126 let dims = StatsDims::from(math);
127 dims.dim_sizes()
128 }
129
130 fn data_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
131 self.data_names(math)
132 .into_iter()
133 .map(|name| (name.clone(), self.data_dims(math, &name)))
134 .collect()
135 }
136
137 fn data_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
138 <M::ExpandedVector as Storable<_>>::dims(math, name)
139 .into_iter()
140 .map(String::from)
141 .collect()
142 }
143
144 fn stat_coords<M: Math>(&self, math: &M) -> HashMap<String, Value> {
145 let dims = StatsDims::from(math);
146 dims.coords()
147 }
148
149 fn stat_event_dims<M: Math>(&self, math: &M) -> Vec<(String, Option<String>)> {
150 let dims = StatsDims::from(math);
151 self.stat_names(math)
152 .into_iter()
153 .map(|name| {
154 let event_dim =
155 <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::event_dim(
156 &dims, &name,
157 )
158 .map(String::from);
159 (name, event_dim)
160 })
161 .collect()
162 }
163}
164
/// Per-draw summary produced together with each draw (see `sample_sequentially`) and folded
/// into `ChainProgress` by `ChainProgress::update`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Progress {
    /// Presumably the index of this draw within its chain — TODO confirm.
    pub draw: u64,
    /// Presumably the id of the chain that produced the draw — TODO confirm.
    pub chain: u64,
    /// Divergence flag; `ChainProgress` only counts it when `tuning` is false.
    pub diverging: bool,
    /// Tuning flag for this draw; copied into `ChainProgress::tuning`.
    pub tuning: bool,
    /// Step size reported for this draw; copied into `ChainProgress::step_size`.
    pub step_size: f64,
    /// Step count reported for this draw; summed into `ChainProgress::total_num_steps`.
    pub num_steps: u64,
}
175
/// Sealing module: `Settings` has `private::Sealed` as a supertrait, and `Sealed` is only
/// implemented for the settings types listed here. Because this module is private, code outside
/// it cannot name `Sealed` and therefore cannot implement `Settings` for other types.
mod private {
    use super::{
        DiagMclmcSettings, DiagNutsSettings, FlowMclmcSettings, FlowNutsSettings,
        LowRankMclmcSettings, LowRankNutsSettings,
    };

    /// Marker supertrait of `Settings`.
    pub trait Sealed {}

    impl Sealed for DiagNutsSettings {}

    impl Sealed for LowRankNutsSettings {}

    impl Sealed for FlowNutsSettings {}

    impl Sealed for DiagMclmcSettings {}

    impl Sealed for LowRankMclmcSettings {}

    impl Sealed for FlowMclmcSettings {}
}
196
/// Configuration of the NUTS sampler, generic over the adaptation options `A`.
///
/// Usually obtained through `Default` on `DiagNutsSettings`, `LowRankNutsSettings` or
/// `FlowNutsSettings` and then adjusted field by field.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct NutsSettings<A: Debug + Copy + Default + Serialize> {
    /// Number of tuning draws per chain. It is passed to the adaptation strategy and counts
    /// towards the total number of draws a chain performs.
    pub num_tune: u64,
    /// Number of draws per chain after tuning.
    pub num_draws: u64,
    /// Forwarded to `NutsOptions::maxdepth` (presumably the maximum tree depth).
    pub maxdepth: u64,
    /// Forwarded to `NutsOptions::mindepth` (presumably the minimum tree depth).
    pub mindepth: u64,
    /// Store the gradient in the point statistics.
    pub store_gradient: bool,
    /// Store the unconstrained position in the point statistics.
    pub store_unconstrained: bool,
    /// Store the transformed position in the point statistics.
    pub store_transformed: bool,
    /// Forwarded to `NutsOptions::max_energy_error`.
    pub max_energy_error: f64,
    /// Forwarded to `NutsOptions` and to the divergence statistics options.
    pub store_divergences: bool,
    /// Options handed to the adaptation strategy's constructor.
    pub adapt_options: A,
    /// Forwarded to `NutsOptions::check_turning`.
    pub check_turning: bool,
    /// Forwarded to `NutsOptions::target_integration_time`.
    pub target_integration_time: Option<f64>,
    /// Kinetic energy kind used to build the Hamiltonian.
    pub trajectory_kind: KineticEnergyKind,
    /// Number of chains to run.
    pub num_chains: usize,
    /// Base seed. The parallel sampler seeds stream 0 with it, and each chain uses stream
    /// `chain_id + 1`.
    pub seed: u64,
    /// Forwarded to `NutsOptions::extra_doublings`.
    pub extra_doublings: u64,
}
240
/// NUTS settings with diagonal mass matrix adaptation.
pub type DiagNutsSettings = NutsSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;
/// Old name of `DiagNutsSettings`.
#[deprecated(since = "0.0.0", note = "Use DiagNutsSettings instead")]
pub type DiagGradNutsSettings = DiagNutsSettings;
/// NUTS settings with low-rank mass matrix adaptation.
pub type LowRankNutsSettings = NutsSettings<EuclideanAdaptOptions<LowRankSettings>>;
/// NUTS settings that adapt an external (flow) transformation.
pub type FlowNutsSettings = NutsSettings<FlowSettings>;
/// Old name of `FlowNutsSettings`.
#[deprecated(since = "0.0.0", note = "Use FlowNutsSettings instead")]
pub type TransformedNutsSettings = FlowNutsSettings;
250
/// Configuration of the MCLMC sampler, generic over the adaptation options `A`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct MclmcSettings<A: Debug + Copy + Default + Serialize> {
    /// Step size. The diagonal and low-rank variants pin the step size adaptation method to
    /// `Fixed(step_size)` when building a chain; the flow variant does not read this field.
    pub step_size: f64,
    /// Passed to `TransformedHamiltonian::set_momentum_decoherence_length`.
    pub momentum_decoherence_length: f64,
    /// Number of tuning draws per chain.
    pub num_tune: u64,
    /// Number of draws per chain after tuning.
    pub num_draws: u64,
    /// Number of chains to run.
    pub num_chains: usize,
    /// Base seed for the random number generators.
    pub seed: u64,
    /// Energy-error threshold passed to `MclmcChain::new`.
    pub max_energy_error: f64,
    /// Store the unconstrained position in the point statistics.
    pub store_unconstrained: bool,
    /// Store the gradient in the point statistics.
    pub store_gradient: bool,
    /// Store the transformed position in the point statistics.
    pub store_transformed: bool,
    /// Forwarded to the divergence statistics options.
    pub store_divergences: bool,
    /// Options handed to the adaptation strategy's constructor.
    pub adapt_options: A,
    /// Passed to `MclmcChain::new`.
    pub subsample_frequency: f64,
    /// Passed to `MclmcChain::new`.
    pub dynamic_step_size: bool,
    /// Trajectory kind. Only `Microcanonical` starts with microcanonical kinetic energy; the
    /// other kinds start Euclidean.
    pub trajectory_kind: MclmcTrajectoryKind,
    /// Fraction of `num_tune`. `trajectory_switch_fraction * num_tune`, truncated to an integer,
    /// is passed to `MclmcChain::new` as the switch draw.
    pub trajectory_switch_fraction: f64,
}
319
/// MCLMC settings with diagonal mass matrix adaptation.
pub type DiagMclmcSettings = MclmcSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;
/// MCLMC settings with low-rank mass matrix adaptation.
pub type LowRankMclmcSettings = MclmcSettings<EuclideanAdaptOptions<LowRankSettings>>;
/// MCLMC settings that adapt an external (flow) transformation.
pub type FlowMclmcSettings = MclmcSettings<FlowSettings>;
/// Old name of `FlowMclmcSettings`.
#[deprecated(since = "0.0.0", note = "Use FlowMclmcSettings instead")]
pub type TransformedMclmcSettings = FlowMclmcSettings;
338
/// Convert a `u64` setting into a `usize` size hint.
///
/// # Panics
/// Panics, naming `field`, when `value` cannot be represented as a `usize`.
fn usize_hint(value: u64, field: &str) -> usize {
    match usize::try_from(value) {
        Ok(converted) => converted,
        Err(_) => panic!("{field} must be smaller than usize::MAX"),
    }
}
344
345fn default_mclmc_settings<A: Debug + Copy + Default + Serialize>(
346 adapt_options: A,
347 num_tune: u64,
348 num_chains: usize,
349 max_energy_error: f64,
350) -> MclmcSettings<A> {
351 MclmcSettings {
352 step_size: 0.5,
353 momentum_decoherence_length: 3.0,
354 num_tune,
355 num_draws: 1000,
356 num_chains,
357 seed: 0,
358 max_energy_error,
359 store_unconstrained: false,
360 store_gradient: false,
361 store_divergences: false,
362 store_transformed: false,
363 adapt_options,
364 subsample_frequency: 1.0,
365 dynamic_step_size: true,
366 trajectory_kind: MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical,
367 trajectory_switch_fraction: 0.3,
368 }
369}
370
371impl Default for DiagMclmcSettings {
372 fn default() -> Self {
373 let mut adapt_options = EuclideanAdaptOptions::default();
374 adapt_options.step_size_settings.adapt_options.method = StepSizeAdaptMethod::Fixed(0.5);
375 default_mclmc_settings(adapt_options, 400, 6, 1000.0)
376 }
377}
378
379impl Default for LowRankMclmcSettings {
380 fn default() -> Self {
381 let mut adapt_options = EuclideanAdaptOptions::default();
382 adapt_options.early_mass_matrix_switch_freq = 20;
383 adapt_options.step_size_settings.adapt_options.method = StepSizeAdaptMethod::Fixed(0.5);
384 default_mclmc_settings(adapt_options, 800, 6, 1000.0)
385 }
386}
387
388impl Default for FlowMclmcSettings {
389 fn default() -> Self {
390 default_mclmc_settings(FlowSettings::default(), 1500, 1, 20.0)
391 }
392}
393
/// MCLMC chain with a diagonal mass matrix, adapted through `GlobalStrategy`.
type DiagMclmcChain<M> = crate::mclmc::MclmcChain<
    M,
    ChaCha8Rng,
    GlobalStrategy<M, DiagAdaptStrategy<M>>,
    DiagMassMatrix<M>,
>;
/// MCLMC chain with a low-rank mass matrix, adapted through `GlobalStrategy`.
type LowRankMclmcChain<M> = crate::mclmc::MclmcChain<
    M,
    ChaCha8Rng,
    GlobalStrategy<M, LowRankMassMatrixStrategy>,
    LowRankMassMatrix<M>,
>;
406
407impl Settings for DiagMclmcSettings {
408 type Chain<M: Math> = DiagMclmcChain<M>;
409
410 fn new_chain<M: Math, R: Rng + ?Sized>(
411 &self,
412 chain: u64,
413 mut math: M,
414 rng: &mut R,
415 ) -> Self::Chain<M> {
416 use crate::dynamics::KineticEnergyKind;
417 use crate::mclmc::MclmcChain;
418 use crate::stepsize::StepSizeAdaptMethod;
419
420 let num_tune = self.num_tune;
421 let mut adapt_options = self.adapt_options;
422 adapt_options.step_size_settings.adapt_options.method =
423 StepSizeAdaptMethod::Fixed(self.step_size);
424 let strategy = GlobalStrategy::<M, DiagAdaptStrategy<M>>::new(
425 &mut math,
426 adapt_options,
427 num_tune,
428 chain,
429 );
430 let mass_matrix = DiagMassMatrix::new(
431 &mut math,
432 self.adapt_options.mass_matrix_options.store_mass_matrix,
433 );
434 let initial_kind = match self.trajectory_kind {
435 MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
436 MclmcTrajectoryKind::Euclidean
437 | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
438 };
439 let mut hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, initial_kind);
440 hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
441 let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
442 let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
443 let stats_options = self.stats_options::<M>();
444 MclmcChain::new(
445 math,
446 hamiltonian,
447 strategy,
448 rng,
449 chain,
450 self.subsample_frequency,
451 self.dynamic_step_size,
452 self.trajectory_kind,
453 switch_draw,
454 self.max_energy_error,
455 stats_options,
456 )
457 }
458
459 fn hint_num_tune(&self) -> usize {
460 usize_hint(self.num_tune, "num_tune")
461 }
462
463 fn hint_num_draws(&self) -> usize {
464 usize_hint(self.num_draws, "num_draws")
465 }
466
467 fn num_chains(&self) -> usize {
468 self.num_chains
469 }
470
471 fn seed(&self) -> u64 {
472 self.seed
473 }
474
475 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
476 StatOptions {
477 adapt: GlobalStrategyStatsOptions {
478 step_size: (),
479 mass_matrix: (),
480 },
481 hamiltonian: -1,
482 point: {
483 let store_gradient = self.store_gradient;
484 let store_unconstrained = self.store_unconstrained;
485 let store_transformed = self.store_transformed;
486 TransformedPointStatsOptions {
487 store_gradient,
488 store_unconstrained,
489 store_transformed,
490 }
491 },
492 divergence: crate::dynamics::DivergenceStatsOptions {
493 store_divergences: self.store_divergences,
494 },
495 }
496 }
497
498 fn sampler_name(&self) -> &'static str {
499 "mclmc"
500 }
501
502 fn adaptation_name(&self) -> &'static str {
503 "diagonal"
504 }
505}
506
507fn default_nuts_settings<A: Debug + Copy + Default + Serialize>(
508 adapt_options: A,
509 num_tune: u64,
510 num_chains: usize,
511 max_energy_error: f64,
512) -> NutsSettings<A> {
513 NutsSettings {
514 num_tune,
515 num_draws: 1000,
516 maxdepth: 10,
517 mindepth: 0,
518 max_energy_error,
519 store_gradient: false,
520 store_unconstrained: false,
521 store_transformed: false,
522 store_divergences: false,
523 adapt_options,
524 check_turning: true,
525 seed: 0,
526 num_chains,
527 target_integration_time: None,
528 trajectory_kind: KineticEnergyKind::Euclidean,
529 extra_doublings: 0,
530 }
531}
532
533impl Settings for LowRankMclmcSettings {
534 type Chain<M: Math> = LowRankMclmcChain<M>;
535
536 fn new_chain<M: Math, R: Rng + ?Sized>(
537 &self,
538 chain: u64,
539 mut math: M,
540 rng: &mut R,
541 ) -> Self::Chain<M> {
542 use crate::dynamics::KineticEnergyKind;
543 use crate::mclmc::MclmcChain;
544 use crate::stepsize::StepSizeAdaptMethod;
545
546 let num_tune = self.num_tune;
547 let mut adapt_options = self.adapt_options;
548 adapt_options.step_size_settings.adapt_options.method =
549 StepSizeAdaptMethod::Fixed(self.step_size);
550 let strategy = GlobalStrategy::<M, LowRankMassMatrixStrategy>::new(
551 &mut math,
552 adapt_options,
553 num_tune,
554 chain,
555 );
556 let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
557 let initial_kind = match self.trajectory_kind {
558 MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
559 MclmcTrajectoryKind::Euclidean
560 | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
561 };
562 let mut hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, initial_kind);
563 hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
564 let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
565 let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
566 let stats_options = self.stats_options::<M>();
567 MclmcChain::new(
568 math,
569 hamiltonian,
570 strategy,
571 rng,
572 chain,
573 self.subsample_frequency,
574 self.dynamic_step_size,
575 self.trajectory_kind,
576 switch_draw,
577 self.max_energy_error,
578 stats_options,
579 )
580 }
581
582 fn hint_num_tune(&self) -> usize {
583 usize_hint(self.num_tune, "num_tune")
584 }
585
586 fn hint_num_draws(&self) -> usize {
587 usize_hint(self.num_draws, "num_draws")
588 }
589
590 fn num_chains(&self) -> usize {
591 self.num_chains
592 }
593
594 fn seed(&self) -> u64 {
595 self.seed
596 }
597
598 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
599 StatOptions {
600 adapt: GlobalStrategyStatsOptions {
601 step_size: (),
602 mass_matrix: (),
603 },
604 hamiltonian: -1,
605 point: {
606 let store_gradient = self.store_gradient;
607 let store_unconstrained = self.store_unconstrained;
608 let store_transformed = self.store_transformed;
609 TransformedPointStatsOptions {
610 store_gradient,
611 store_unconstrained,
612 store_transformed,
613 }
614 },
615 divergence: crate::dynamics::DivergenceStatsOptions {
616 store_divergences: self.store_divergences,
617 },
618 }
619 }
620
621 fn sampler_name(&self) -> &'static str {
622 "mclmc"
623 }
624
625 fn adaptation_name(&self) -> &'static str {
626 "low_rank"
627 }
628}
629
630impl Default for DiagNutsSettings {
631 fn default() -> Self {
632 default_nuts_settings(EuclideanAdaptOptions::default(), 400, 6, 1000.0)
633 }
634}
635
636impl Default for LowRankNutsSettings {
637 fn default() -> Self {
638 let mut vals = default_nuts_settings(EuclideanAdaptOptions::default(), 800, 6, 1000.0);
639 vals.adapt_options.mass_matrix_update_freq = 20;
640 vals
641 }
642}
643
644impl Default for FlowNutsSettings {
645 fn default() -> Self {
646 default_nuts_settings(FlowSettings::default(), 1500, 1, 20.0)
647 }
648}
649
/// NUTS chain adapting a diagonal mass matrix through `GlobalStrategy`.
type DiagNutsChain<M> = NutsChain<M, ChaCha8Rng, GlobalStrategy<M, DiagAdaptStrategy<M>>>;
/// NUTS chain adapting a low-rank mass matrix through `GlobalStrategy`.
type LowRankNutsChain<M> = NutsChain<M, ChaCha8Rng, GlobalStrategy<M, LowRankMassMatrixStrategy>>;
652
653fn nuts_options(settings: &NutsSettings<impl Debug + Copy + Default + Serialize>) -> NutsOptions {
654 NutsOptions {
655 maxdepth: settings.maxdepth,
656 mindepth: settings.mindepth,
657 store_divergences: settings.store_divergences,
658 check_turning: settings.check_turning,
659 target_integration_time: settings.target_integration_time,
660 extra_doublings: settings.extra_doublings,
661 max_energy_error: settings.max_energy_error,
662 }
663}
664
665impl Settings for LowRankNutsSettings {
666 type Chain<M: Math> = LowRankNutsChain<M>;
667
668 fn new_chain<M: Math, R: Rng + ?Sized>(
669 &self,
670 chain: u64,
671 mut math: M,
672 mut rng: &mut R,
673 ) -> Self::Chain<M> {
674 let num_tune = self.num_tune;
675 let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
676 let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
677 let hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, self.trajectory_kind);
678
679 let options = nuts_options(self);
680
681 let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
682
683 NutsChain::new(
684 math,
685 hamiltonian,
686 strategy,
687 options,
688 rng,
689 chain,
690 self.stats_options(),
691 )
692 }
693
694 fn hint_num_tune(&self) -> usize {
695 usize_hint(self.num_tune, "num_tune")
696 }
697
698 fn hint_num_draws(&self) -> usize {
699 usize_hint(self.num_draws, "num_draws")
700 }
701
702 fn num_chains(&self) -> usize {
703 self.num_chains
704 }
705
706 fn seed(&self) -> u64 {
707 self.seed
708 }
709
710 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
711 StatOptions {
712 adapt: GlobalStrategyStatsOptions {
713 mass_matrix: (),
714 step_size: (),
715 },
716 hamiltonian: -1,
717 point: {
718 let store_gradient = self.store_gradient;
719 let store_unconstrained = self.store_unconstrained;
720 let store_transformed = self.store_transformed;
721 TransformedPointStatsOptions {
722 store_gradient,
723 store_unconstrained,
724 store_transformed,
725 }
726 },
727 divergence: crate::dynamics::DivergenceStatsOptions {
728 store_divergences: self.store_divergences,
729 },
730 }
731 }
732
733 fn sampler_name(&self) -> &'static str {
734 "nuts"
735 }
736
737 fn adaptation_name(&self) -> &'static str {
738 "low_rank"
739 }
740}
741
742impl Settings for DiagNutsSettings {
743 type Chain<M: Math> = DiagNutsChain<M>;
744
745 fn new_chain<M: Math, R: Rng + ?Sized>(
746 &self,
747 chain: u64,
748 mut math: M,
749 mut rng: &mut R,
750 ) -> Self::Chain<M> {
751 let num_tune = self.num_tune;
752 let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
753 let mass_matrix = DiagMassMatrix::new(
754 &mut math,
755 self.adapt_options.mass_matrix_options.store_mass_matrix,
756 );
757 let potential = TransformedHamiltonian::new(&mut math, mass_matrix, self.trajectory_kind);
758
759 let options = nuts_options(self);
760
761 let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
762
763 NutsChain::new(
764 math,
765 potential,
766 strategy,
767 options,
768 rng,
769 chain,
770 self.stats_options(),
771 )
772 }
773
774 fn hint_num_tune(&self) -> usize {
775 usize_hint(self.num_tune, "num_tune")
776 }
777
778 fn hint_num_draws(&self) -> usize {
779 usize_hint(self.num_draws, "num_draws")
780 }
781
782 fn num_chains(&self) -> usize {
783 self.num_chains
784 }
785
786 fn seed(&self) -> u64 {
787 self.seed
788 }
789
790 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
791 StatOptions {
792 adapt: GlobalStrategyStatsOptions {
793 mass_matrix: (),
794 step_size: (),
795 },
796 hamiltonian: -1,
797 point: {
798 let store_gradient = self.store_gradient;
799 let store_unconstrained = self.store_unconstrained;
800 let store_transformed = self.store_transformed;
801 TransformedPointStatsOptions {
802 store_gradient,
803 store_unconstrained,
804 store_transformed,
805 }
806 },
807 divergence: crate::dynamics::DivergenceStatsOptions {
808 store_divergences: self.store_divergences,
809 },
810 }
811 }
812
813 fn sampler_name(&self) -> &'static str {
814 "nuts"
815 }
816
817 fn adaptation_name(&self) -> &'static str {
818 "diagonal"
819 }
820}
821
822impl Settings for FlowNutsSettings {
823 type Chain<M: Math> = NutsChain<M, ChaCha8Rng, ExternalTransformAdaptation>;
824
825 fn new_chain<M: Math, R: Rng + ?Sized>(
826 &self,
827 chain: u64,
828 mut math: M,
829 mut rng: &mut R,
830 ) -> Self::Chain<M> {
831 let num_tune = self.num_tune;
832
833 let strategy =
834 ExternalTransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
835 let params = math
836 .new_transformation(rng, math.dim(), chain)
837 .expect("Failed to create external transformation");
838 let transform = ExternalTransformation::new(params);
839 let hamiltonian = TransformedHamiltonian::new(&mut math, transform, self.trajectory_kind);
840
841 let options = nuts_options(self);
842
843 let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
844 NutsChain::new(
845 math,
846 hamiltonian,
847 strategy,
848 options,
849 rng,
850 chain,
851 self.stats_options(),
852 )
853 }
854
855 fn hint_num_tune(&self) -> usize {
856 usize_hint(self.num_tune, "num_tune")
857 }
858
859 fn hint_num_draws(&self) -> usize {
860 usize_hint(self.num_draws, "num_draws")
861 }
862
863 fn num_chains(&self) -> usize {
864 self.num_chains
865 }
866
867 fn seed(&self) -> u64 {
868 self.seed
869 }
870
871 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
872 StatOptions {
873 adapt: (),
874 hamiltonian: (),
875 point: {
876 let store_gradient = self.store_gradient;
877 let store_unconstrained = self.store_unconstrained;
878 let store_transformed = self.store_transformed;
879 TransformedPointStatsOptions {
880 store_gradient,
881 store_unconstrained,
882 store_transformed,
883 }
884 },
885 divergence: crate::dynamics::DivergenceStatsOptions {
886 store_divergences: self.store_divergences,
887 },
888 }
889 }
890
891 fn sampler_name(&self) -> &'static str {
892 "nuts"
893 }
894
895 fn adaptation_name(&self) -> &'static str {
896 "flow"
897 }
898}
899
900impl Settings for FlowMclmcSettings {
901 type Chain<M: Math> = crate::mclmc::MclmcChain<
902 M,
903 ChaCha8Rng,
904 ExternalTransformAdaptation,
905 ExternalTransformation<M>,
906 >;
907
908 fn new_chain<M: Math, R: Rng + ?Sized>(
909 &self,
910 chain: u64,
911 mut math: M,
912 rng: &mut R,
913 ) -> Self::Chain<M> {
914 use crate::dynamics::KineticEnergyKind;
915 use crate::mclmc::MclmcChain;
916
917 let num_tune = self.num_tune;
918 let strategy =
919 ExternalTransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
920 let params = math
921 .new_transformation(rng, math.dim(), chain)
922 .expect("Failed to create external transformation");
923 let transform = ExternalTransformation::new(params);
924 let initial_kind = match self.trajectory_kind {
925 MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
926 MclmcTrajectoryKind::Euclidean
927 | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
928 };
929 let mut hamiltonian = TransformedHamiltonian::new(&mut math, transform, initial_kind);
930 hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
931 let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
932 let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
933 let stats_options = self.stats_options::<M>();
934 MclmcChain::new(
935 math,
936 hamiltonian,
937 strategy,
938 rng,
939 chain,
940 self.subsample_frequency,
941 self.dynamic_step_size,
942 self.trajectory_kind,
943 switch_draw,
944 self.max_energy_error,
945 stats_options,
946 )
947 }
948
949 fn hint_num_tune(&self) -> usize {
950 usize_hint(self.num_tune, "num_tune")
951 }
952
953 fn hint_num_draws(&self) -> usize {
954 usize_hint(self.num_draws, "num_draws")
955 }
956
957 fn num_chains(&self) -> usize {
958 self.num_chains
959 }
960
961 fn seed(&self) -> u64 {
962 self.seed
963 }
964
965 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
966 StatOptions {
967 adapt: (),
968 hamiltonian: (),
969 point: {
970 let store_gradient = self.store_gradient;
971 let store_unconstrained = self.store_unconstrained;
972 let store_transformed = self.store_transformed;
973 TransformedPointStatsOptions {
974 store_gradient,
975 store_unconstrained,
976 store_transformed,
977 }
978 },
979 divergence: crate::dynamics::DivergenceStatsOptions {
980 store_divergences: self.store_divergences,
981 },
982 }
983 }
984
985 fn sampler_name(&self) -> &'static str {
986 "mclmc"
987 }
988
989 fn adaptation_name(&self) -> &'static str {
990 "flow"
991 }
992}
993
994pub fn sample_sequentially<'math, M: Math + 'math, R: Rng + ?Sized>(
995 math: M,
996 settings: DiagNutsSettings,
997 start: &[f64],
998 draws: u64,
999 chain: u64,
1000 rng: &mut R,
1001) -> Result<impl Iterator<Item = Result<(Box<[f64]>, Progress)>> + 'math> {
1002 let mut sampler = settings.new_chain(chain, math, rng);
1003 sampler.set_position(start)?;
1004 Ok((0..draws).map(move |_| sampler.draw()))
1005}
1006
/// Running progress counters for one chain, updated after every draw.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct ChainProgress {
    /// Draws completed so far, tuning included.
    pub finished_draws: usize,
    /// Total draws the chain is expected to perform (tuning plus regular draws).
    pub total_draws: usize,
    /// Divergent draws seen after tuning; divergences during tuning are not counted.
    pub divergences: usize,
    /// Tuning flag of the most recent draw; `true` before the first draw.
    pub tuning: bool,
    /// Set once the worker has built its chain, before initialization starts.
    pub started: bool,
    /// Step count of the most recent draw.
    pub latest_num_steps: usize,
    /// Sum of step counts over all draws.
    pub total_num_steps: usize,
    /// Step size reported by the most recent draw.
    pub step_size: f64,
    /// Accumulated time spent drawing.
    pub runtime: Duration,
    /// Values of `finished_draws` at which post-tuning divergences occurred.
    pub divergent_draws: Vec<usize>,
}
1021
1022impl ChainProgress {
1023 fn new(total: usize) -> Self {
1024 Self {
1025 finished_draws: 0,
1026 total_draws: total,
1027 divergences: 0,
1028 tuning: true,
1029 started: false,
1030 latest_num_steps: 0,
1031 step_size: 0f64,
1032 total_num_steps: 0,
1033 runtime: Duration::ZERO,
1034 divergent_draws: Vec::new(),
1035 }
1036 }
1037
1038 fn update(&mut self, stats: &Progress, draw_duration: Duration) {
1039 if stats.diverging & !stats.tuning {
1040 self.divergences += 1;
1041 self.divergent_draws.push(self.finished_draws);
1042 }
1043 self.finished_draws += 1;
1044 self.tuning = stats.tuning;
1045
1046 self.latest_num_steps = stats.num_steps as usize;
1047 self.total_num_steps += stats.num_steps as usize;
1048 self.step_size = stats.step_size;
1049 self.runtime += draw_duration;
1050 }
1051}
1052
/// Control message sent from a `ChainProcess` handle to its worker.
#[cfg(feature = "parallel")]
enum ChainCommand {
    /// Keep drawing; also wakes a paused worker.
    Resume,
    /// Block the worker until the next command arrives or the handle is dropped.
    Pause,
}
1058
/// Handle to one chain running as a task on the sampler's thread pool.
#[cfg(feature = "parallel")]
struct ChainProcess<T>
where
    T: TraceStorage,
{
    /// Sends pause/resume commands to the worker. Dropping it disconnects the channel, and the
    /// worker stops at its next command check.
    stop_marker: Sender<ChainCommand>,
    /// The chain's trace, shared with the worker. Taking it out (leaving `None`) makes the
    /// worker stop after its next draw.
    trace: Arc<Mutex<Option<T::ChainStorage>>>,
    /// Progress counters, updated by the worker after every draw.
    progress: Arc<Mutex<ChainProgress>>,
}
1068
#[cfg(feature = "parallel")]
impl<T: TraceStorage> ChainProcess<T> {
    /// Finalize the per-chain traces of `chains` and combine them through `trace.finalize`.
    ///
    /// Each chain's trace is taken out of its shared slot; chains whose slot is already empty are
    /// skipped. A worker whose slot has been emptied stops after its next draw.
    fn finalize_many(trace: T, chains: Vec<Self>) -> Result<(Option<anyhow::Error>, T::Finalized)> {
        let finalized_chain_traces = chains
            .into_iter()
            .filter_map(|chain| chain.trace.lock().expect("Poisoned lock").take())
            .map(|chain| chain.finalize())
            .collect_vec();
        trace.finalize(finalized_chain_traces)
    }

    /// Snapshot of the chain's current progress counters.
    fn progress(&self) -> ChainProgress {
        self.progress.lock().expect("Poisoned lock").clone()
    }

    /// Ask the worker to resume drawing.
    ///
    /// # Errors
    /// Fails if the worker has already exited and dropped its receiver.
    fn resume(&self) -> Result<()> {
        self.stop_marker.send(ChainCommand::Resume)?;
        Ok(())
    }

    /// Ask the worker to pause before its next draw.
    ///
    /// # Errors
    /// Fails if the worker has already exited and dropped its receiver.
    fn pause(&self) -> Result<()> {
        self.stop_marker.send(ChainCommand::Pause)?;
        Ok(())
    }

    /// Spawn the worker for chain `chain_id` on `scope` and return a handle to it.
    ///
    /// The worker seeds a ChaCha8 RNG from `seed` on stream `chain_id + 1`, then:
    /// 1. builds the model density and the chain;
    /// 2. tries up to 500 initial positions;
    /// 3. performs `hint_num_tune() + hint_num_draws()` draws, recording each one into
    ///    `chain_trace`.
    ///
    /// It stops early if the handle is dropped or the trace is taken by `finalize_many`. Its
    /// final `Result` (including draw and storage errors) is sent on `results`.
    ///
    /// # Errors
    /// Currently always returns `Ok`; failures inside the worker are reported via `results`.
    fn start<'model, M: Model, S: Settings>(
        model: &'model M,
        chain_trace: T::ChainStorage,
        chain_id: u64,
        seed: u64,
        settings: &'model S,
        scope: &ScopeFifo<'model>,
        results: Sender<Result<()>>,
    ) -> Result<Self> {
        let (stop_marker_tx, stop_marker_rx) = channel();

        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        rng.set_stream(chain_id + 1);

        let chain_trace = Arc::new(Mutex::new(Some(chain_trace)));
        let progress = Arc::new(Mutex::new(ChainProgress::new(
            settings.hint_num_draws() + settings.hint_num_tune(),
        )));

        let trace_inner = chain_trace.clone();
        let progress_inner = progress.clone();

        scope.spawn_fifo(move |_| {
            let chain_trace = trace_inner;
            let progress = progress_inner;

            let mut sample = move || {
                let logp = model
                    .math(&mut rng)
                    .context("Failed to create model density")?;
                let dim = logp.dim();

                let mut sampler = settings.new_chain(chain_id, logp, &mut rng);

                progress.lock().expect("Poisoned mutex").started = true;

                // Try up to 500 random initial positions; keep the last error if none works.
                let mut initval = vec![0f64; dim];
                let mut error = None;
                for _ in 0..500 {
                    model
                        .init_position(&mut rng, &mut initval)
                        .context("Failed to generate a new initial position")?;
                    if let Err(err) = sampler.set_position(&initval) {
                        error = Some(err);
                        continue;
                    }
                    error = None;
                    break;
                }

                if let Some(error) = error {
                    return Err(error.context("All initialization points failed"));
                }

                let draws = settings.hint_num_tune() + settings.hint_num_draws();

                let mut msg = stop_marker_rx.try_recv();
                let mut draw = 0;
                // Check the count before drawing. Otherwise a run with zero requested draws
                // would never reach `draw == draws` and would keep sampling indefinitely.
                while draw < draws {
                    match msg {
                        Err(TryRecvError::Disconnected) => {
                            // The `ChainProcess` handle was dropped: stop sampling.
                            break;
                        }
                        Err(TryRecvError::Empty) => {}
                        Ok(ChainCommand::Pause) => {
                            // Block for the next command (or disconnect) and re-dispatch it.
                            msg = stop_marker_rx.recv().map_err(|e| e.into());
                            continue;
                        }
                        Ok(ChainCommand::Resume) => {}
                    }

                    let now = Instant::now();
                    // Report draw failures through `results` instead of panicking the worker.
                    let (_point, mut draw_data, mut stats, info) = sampler
                        .expanded_draw()
                        .context("Failed to generate a draw")?;

                    let mut guard = chain_trace
                        .lock()
                        .expect("Could not unlock trace lock. Poisoned mutex");

                    // The trace was taken (e.g. by `finalize_many`): stop sampling.
                    let Some(trace_val) = guard.as_mut() else {
                        break;
                    };
                    progress
                        .lock()
                        .expect("Poisoned mutex")
                        .update(&info, now.elapsed());

                    let math = sampler.math();
                    let dims = StatsDims::from(math.deref());
                    trace_val.record_sample(
                        settings,
                        stats.get_all(&dims),
                        draw_data.get_all(math.deref()),
                        &info,
                    )?;

                    draw += 1;
                    msg = stop_marker_rx.try_recv();
                }
                Ok(())
            };

            let result = sample();

            // The receiver may already be gone if the sampler was dropped; nothing to do then.
            let _ = results.send(result);
            drop(results);
        });

        Ok(Self {
            trace: chain_trace,
            stop_marker: stop_marker_tx,
            progress,
        })
    }

    /// Flush the chain's trace if it is still present.
    ///
    /// # Errors
    /// Fails if the trace mutex is poisoned or flushing the trace fails.
    fn flush(&self) -> Result<()> {
        self.trace
            .lock()
            .map_err(|_| anyhow::anyhow!("Could not lock trace mutex"))
            .context("Could not flush trace")?
            .as_mut()
            .map(|v| v.flush())
            .transpose()?;
        Ok(())
    }
}
1228
/// Request sent to the sampler's main thread over its command channel.
///
/// NOTE(review): the handling of each variant is outside this view; the names suggest they map
/// onto `ChainProcess::pause` / `resume` / `progress` / `flush` plus an inspection request —
/// confirm against the main thread's command loop.
#[cfg(feature = "parallel")]
#[derive(Debug)]
enum SamplerCommand {
    Pause,
    Continue,
    Progress,
    Flush,
    Inspect,
}
1238
/// Reply from the sampler's main thread to a `SamplerCommand`.
#[cfg(feature = "parallel")]
enum SamplerResponse<T: Send + 'static> {
    /// Command carried out; no payload.
    Ok(),
    /// Per-chain progress snapshots.
    Progress(Box<[ChainProgress]>),
    /// Payload of an inspection request (presumably a finalized snapshot of the trace — TODO
    /// confirm).
    Inspect(T),
}
1245
/// Outcome of waiting on a running `Sampler`.
#[cfg(feature = "parallel")]
pub enum SamplerWaitResult<F: Send + 'static> {
    /// Sampling finished; holds the finalized trace.
    Trace(F),
    /// The wait timed out; the still-running sampler is handed back.
    Timeout(Sampler<F>),
    /// Sampling failed; holds the error and, presumably, whatever trace could still be
    /// finalized — TODO confirm.
    Err(anyhow::Error, Option<F>),
}
1252
/// Handle to a multi-chain sampling run executing on a background thread.
#[cfg(feature = "parallel")]
pub struct Sampler<F: Send + 'static> {
    /// Background thread that owns the thread pool and, at the end, yields an optional error
    /// together with the finalized trace.
    main_thread: JoinHandle<Result<(Option<anyhow::Error>, F)>>,
    /// Command channel to the main thread; created as a rendezvous channel (`sync_channel(0)`).
    commands: SyncSender<SamplerCommand>,
    /// Replies to `commands`; also a rendezvous channel.
    responses: Receiver<SamplerResponse<(Option<anyhow::Error>, F)>>,
    /// Final result of each chain worker, sent when the worker finishes.
    results: Receiver<Result<()>>,
}
1260
/// Periodic progress hook that can be passed to [`Sampler::new`].
#[cfg(feature = "parallel")]
pub struct ProgressCallback {
    /// Invoked with the elapsed sampling time and a snapshot of every chain's progress.
    pub callback: Box<dyn FnMut(Duration, Box<[ChainProgress]>) + Send>,
    /// Interval between two periodic invocations of `callback`.
    pub rate: Duration,
}
1266
#[cfg(feature = "parallel")]
impl<F: Send + 'static> Sampler<F> {
    /// Start sampling `model` in the background and return a handle to it.
    ///
    /// Spawns a controller thread that:
    /// - builds a rayon pool with `num_cores + 1` threads;
    /// - creates the trace from `trace_config`;
    /// - starts `settings.num_chains()` chain processes on that pool;
    /// - services the commands sent through the returned handle.
    ///
    /// If `callback` is given, the controller invokes it:
    /// - once at startup;
    /// - then roughly every `callback.rate` with the elapsed time (excluding paused
    ///   periods) and per-chain progress;
    /// - once more when the command channel is closed.
    ///
    /// # Errors
    ///
    /// Setup failures on the controller thread (thread pool, model density, trace or
    /// chain creation) are not returned here. They are reported when the controller
    /// is joined through [`Sampler::abort`] or [`Sampler::wait_timeout`].
    pub fn new<M, S, C, T>(
        model: M,
        settings: S,
        trace_config: C,
        num_cores: usize,
        callback: Option<ProgressCallback>,
    ) -> Result<Self>
    where
        S: Settings,
        C: StorageConfig<Storage = T>,
        M: Model,
        T: TraceStorage<Finalized = F>,
    {
        // Zero-capacity channels: each command/response is a rendezvous between
        // the handle and the controller thread.
        let (commands_tx, commands_rx) = sync_channel(0);
        let (responses_tx, responses_rx) = sync_channel(0);
        // Chains report their final result here, one message per chain.
        let (results_tx, results_rx) = channel();

        let main_thread = spawn(move || {
            // NOTE(review): the extra thread is presumably reserved for the
            // controller closure, which `scope_fifo` runs inside the pool.
            let pool = ThreadPoolBuilder::new()
                .num_threads(num_cores + 1)
                .thread_name(|i| format!("nutpie-worker-{i}"))
                .build()
                .context("Could not start thread pool")?;

            let settings_ref = &settings;
            let model_ref = &model;
            let mut callback = callback;

            pool.scope_fifo(move |scope| {
                let results = results_tx;
                let mut chains = Vec::with_capacity(settings.num_chains());

                // Stream 0 of the seeded generator is used for the setup below.
                let mut rng = ChaCha8Rng::seed_from_u64(settings.seed());
                rng.set_stream(0);

                // This density instance is only needed to create the trace.
                let math = model_ref
                    .math(&mut rng)
                    .context("Could not create model density")?;
                let trace = trace_config
                    .new_trace(settings_ref, &math)
                    .context("Could not create trace object")?;
                drop(math);

                for chain_id in 0..settings.num_chains() {
                    let chain_trace_val = trace
                        .initialize_trace_for_chain(chain_id as u64)
                        .context("Failed to create trace object")?;
                    let chain = ChainProcess::start(
                        model_ref,
                        chain_trace_val,
                        chain_id as u64,
                        settings.seed(),
                        settings_ref,
                        scope,
                        results.clone(),
                    );
                    chains.push(chain);
                }
                // Only the chains hold result senders now, so the receiver
                // disconnects once every chain has reported.
                drop(results);

                let (chains, errors): (Vec<_>, Vec<_>) = chains.into_iter().partition_result();
                if let Some(error) = errors.into_iter().next() {
                    // Best effort: shut down the chains that did start.
                    let _ = ChainProcess::finalize_many(trace, chains);
                    return Err(error).context("Could not start chains");
                }

                let mut main_loop = || {
                    let start_time = Instant::now();
                    let mut pause_start = Instant::now();
                    // Total time spent in completed pauses.
                    let mut pause_time = Duration::ZERO;

                    // Without a callback, `Duration::MAX` makes the command wait
                    // effectively unbounded.
                    let mut progress_rate = Duration::MAX;
                    if let Some(ProgressCallback { callback, rate }) = &mut callback {
                        let progress = chains.iter().map(|chain| chain.progress()).collect_vec();
                        callback(start_time.elapsed(), progress.into());
                        progress_rate = *rate;
                    }
                    let mut last_progress = Instant::now();
                    let mut is_paused = false;

                    loop {
                        // Wait for a command at most until the next progress report
                        // is due; if it is already due, report now and wait a full period.
                        let timeout = progress_rate.checked_sub(last_progress.elapsed());
                        let timeout = timeout.unwrap_or_else(|| {
                            if let Some(ProgressCallback { callback, .. }) = &mut callback {
                                let progress =
                                    chains.iter().map(|chain| chain.progress()).collect_vec();
                                let mut elapsed = start_time.elapsed().saturating_sub(pause_time);
                                if is_paused {
                                    elapsed = elapsed.saturating_sub(pause_start.elapsed());
                                }
                                callback(elapsed, progress.into());
                            }
                            last_progress = Instant::now();
                            progress_rate
                        });

                        match commands_rx.recv_timeout(timeout) {
                            Ok(SamplerCommand::Pause) => {
                                for chain in chains.iter() {
                                    let _ = chain.pause();
                                }
                                // A repeated pause keeps the original pause start.
                                if !is_paused {
                                    pause_start = Instant::now();
                                }
                                is_paused = true;
                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
                                    anyhow::anyhow!(
                                        "Could not send pause response to controller thread: {e}"
                                    )
                                })?;
                            }
                            Ok(SamplerCommand::Continue) => {
                                for chain in chains.iter() {
                                    let _ = chain.resume();
                                }
                                // Only a real pause contributes to the pause time;
                                // resuming a running sampler must not shift the clock.
                                if is_paused {
                                    pause_time += pause_start.elapsed();
                                }
                                is_paused = false;
                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
                                    anyhow::anyhow!(
                                        "Could not send continue response to controller thread: {e}"
                                    )
                                })?;
                            }
                            Ok(SamplerCommand::Progress) => {
                                let progress =
                                    chains.iter().map(|chain| chain.progress()).collect_vec();
                                responses_tx
                                    .send(SamplerResponse::Progress(progress.into()))
                                    .map_err(|e| {
                                        anyhow::anyhow!(
                                            "Could not send progress response to controller thread: {e}"
                                        )
                                    })?;
                            }
                            Ok(SamplerCommand::Inspect) => {
                                let traces = chains
                                    .iter()
                                    .filter_map(|chain| {
                                        chain
                                            .trace
                                            .lock()
                                            .expect("Poisoned lock")
                                            .as_ref()
                                            .map(|v| v.inspect())
                                    })
                                    .collect_vec();
                                let finalized_trace = trace.inspect(traces)?;
                                responses_tx
                                    .send(SamplerResponse::Inspect(finalized_trace))
                                    .map_err(|e| {
                                        anyhow::anyhow!(
                                            "Could not send inspect response to controller thread: {e}"
                                        )
                                    })?;
                            }
                            Ok(SamplerCommand::Flush) => {
                                for chain in chains.iter() {
                                    chain.flush()?;
                                }
                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
                                    anyhow::anyhow!(
                                        "Could not send flush response to controller thread: {e}"
                                    )
                                })?;
                            }
                            Err(RecvTimeoutError::Timeout) => {}
                            // The handle dropped its command sender (e.g. `abort`):
                            // report final progress and leave the loop.
                            Err(RecvTimeoutError::Disconnected) => {
                                if let Some(ProgressCallback { callback, .. }) = &mut callback {
                                    let progress =
                                        chains.iter().map(|chain| chain.progress()).collect_vec();
                                    let mut elapsed =
                                        start_time.elapsed().saturating_sub(pause_time);
                                    if is_paused {
                                        elapsed = elapsed.saturating_sub(pause_start.elapsed());
                                    }
                                    callback(elapsed, progress.into());
                                }
                                return Ok(());
                            }
                        };
                    }
                };
                let result: Result<()> = main_loop();
                // Finalize the chains even if the loop failed.
                let output = ChainProcess::finalize_many(trace, chains)?;

                result?;
                Ok(output)
            })
        });

        Ok(Self {
            main_thread,
            commands: commands_tx,
            responses: responses_rx,
            results: results_rx,
        })
    }

    /// Pause all chains and wait for the controller to acknowledge.
    ///
    /// Pausing an already paused sampler is accepted.
    ///
    /// # Errors
    ///
    /// Fails if the controller thread is gone or replies unexpectedly.
    pub fn pause(&mut self) -> Result<()> {
        self.commands
            .send(SamplerCommand::Pause)
            .context("Could not send pause command to controller thread")?;
        let response = self
            .responses
            .recv()
            .context("Could not receive pause response from controller thread")?;
        let SamplerResponse::Ok() = response else {
            bail!("Got invalid response from sample controller thread");
        };
        Ok(())
    }

    /// Resume all chains and wait for the controller to acknowledge.
    ///
    /// Resuming a sampler that is not paused is accepted and does not change the
    /// elapsed time reported to the progress callback.
    ///
    /// # Errors
    ///
    /// Fails if the controller thread is gone or replies unexpectedly.
    pub fn resume(&mut self) -> Result<()> {
        self.commands
            .send(SamplerCommand::Continue)
            .context("Could not send continue command to controller thread")?;
        let response = self
            .responses
            .recv()
            .context("Could not receive continue response from controller thread")?;
        let SamplerResponse::Ok() = response else {
            bail!("Got invalid response from sample controller thread");
        };
        Ok(())
    }

    /// Flush the trace storage of every chain.
    ///
    /// # Errors
    ///
    /// Fails if the controller thread is gone or replies unexpectedly. A flush
    /// failure on the controller ends its command loop, so it surfaces here as a
    /// receive error.
    pub fn flush(&mut self) -> Result<()> {
        self.commands
            .send(SamplerCommand::Flush)
            .context("Could not send flush command to controller thread")?;
        let response = self
            .responses
            .recv()
            .context("Could not receive flush response from controller thread")?;
        let SamplerResponse::Ok() = response else {
            bail!("Got invalid response from sample controller thread");
        };
        Ok(())
    }

    /// Build an intermediate trace from the data each chain has recorded so far.
    ///
    /// Sampling is not paused for this.
    ///
    /// # Errors
    ///
    /// Fails if the controller thread is gone or replies unexpectedly.
    pub fn inspect(&mut self) -> Result<(Option<anyhow::Error>, F)> {
        self.commands
            .send(SamplerCommand::Inspect)
            .context("Could not send inspect command to controller thread")?;
        let response = self
            .responses
            .recv()
            .context("Could not receive inspect response from controller thread")?;
        let SamplerResponse::Inspect(trace) = response else {
            bail!("Got invalid response from sample controller thread");
        };
        Ok(trace)
    }

    /// Stop the controller and return the finalized trace.
    ///
    /// Dropping the command sender ends the controller's command loop. The
    /// controller then finalizes all chains, and this call joins its thread.
    /// A panic on the controller thread is resumed on the caller.
    ///
    /// # Errors
    ///
    /// Returns the controller's error if setup, the command loop, or finalization failed.
    pub fn abort(self) -> Result<(Option<anyhow::Error>, F)> {
        drop(self.commands);
        let result = self.main_thread.join();
        match result {
            Err(payload) => std::panic::resume_unwind(payload),
            Ok(Ok(val)) => Ok(val),
            Ok(Err(err)) => Err(err),
        }
    }

    /// Wait at most `timeout` in total for sampling to finish.
    ///
    /// Returns:
    /// - the first chain error as soon as it arrives, without a trace;
    /// - the finalized trace (via [`Sampler::abort`]) once every result sender is
    ///   gone, i.e. all chains have reported;
    /// - `Timeout(self)` when the time budget is exhausted first.
    pub fn wait_timeout(self, timeout: Duration) -> SamplerWaitResult<F> {
        let start = Instant::now();
        let mut remaining = Some(timeout);
        // Each receive only waits for what is left of the overall budget, so the
        // total wait never exceeds `timeout` no matter how many chains report.
        while let Some(budget) = remaining {
            match self.results.recv_timeout(budget) {
                Ok(Ok(())) => remaining = timeout.checked_sub(start.elapsed()),
                Ok(Err(e)) => return SamplerWaitResult::Err(e, None),
                Err(RecvTimeoutError::Disconnected) => {
                    return match self.abort() {
                        Ok((Some(err), trace)) => SamplerWaitResult::Err(err, Some(trace)),
                        Ok((None, trace)) => SamplerWaitResult::Trace(trace),
                        Err(err) => SamplerWaitResult::Err(err, None),
                    };
                }
                Err(RecvTimeoutError::Timeout) => break,
            }
        }
        SamplerWaitResult::Timeout(self)
    }

    /// Return a snapshot of the progress of every chain.
    ///
    /// # Errors
    ///
    /// Fails if the controller thread is gone or replies unexpectedly.
    pub fn progress(&mut self) -> Result<Box<[ChainProgress]>> {
        self.commands
            .send(SamplerCommand::Progress)
            .context("Could not send progress command to controller thread")?;
        let response = self
            .responses
            .recv()
            .context("Could not receive progress response from controller thread")?;
        let SamplerResponse::Progress(progress) = response else {
            bail!("Got invalid response from sample controller thread");
        };
        Ok(progress)
    }
}
1553
#[cfg(test)]
pub mod test_logps {
    use crate::{Model, math::CpuLogpFunc, math::CpuMath};
    use anyhow::Result;
    use rand::Rng;

    /// Test-only [`Model`] that wraps a CPU log-density function.
    pub struct CpuModel<F> {
        logp: F,
    }

    impl<F> CpuModel<F> {
        /// Wrap `logp` so it can be handed to the sampler.
        pub fn new(logp: F) -> Self {
            CpuModel { logp }
        }
    }

    impl<F> Model for CpuModel<F>
    where
        F: Send + Sync + 'static,
        for<'a> &'a F: CpuLogpFunc,
    {
        type Math<'model> = CpuMath<&'model F>;

        /// Borrow the wrapped density; the rng is not used.
        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            let logp = &self.logp;
            Ok(CpuMath::new(logp))
        }

        /// Start every chain at the origin; the rng is not used.
        fn init_position<R: rand::prelude::Rng + ?Sized>(
            &self,
            _rng: &mut R,
            position: &mut [f64],
        ) -> Result<()> {
            position.fill(0.0);
            Ok(())
        }
    }
}
1591
#[cfg(test)]
mod tests {
    use crate::math::test_logps::NormalLogp;
    use crate::{
        Chain, math::CpuMath, sample_sequentially, sampler::DiagMclmcSettings,
        sampler::DiagNutsSettings, sampler::LowRankMclmcSettings, sampler::LowRankNutsSettings,
        sampler::Settings,
    };

    #[cfg(feature = "zarr")]
    use super::test_logps::CpuModel;

    use anyhow::Result;
    use itertools::Itertools;
    use pretty_assertions::assert_eq;
    use rand::{SeedableRng, rngs::StdRng};

    #[cfg(feature = "zarr")]
    use std::{
        sync::Arc,
        time::{Duration, Instant},
    };

    #[cfg(feature = "zarr")]
    use crate::{Sampler, ZarrConfig};

    #[cfg(feature = "zarr")]
    use zarrs::storage::store::MemoryStore;

    /// Build a chain from `settings` on a 4-dimensional normal target.
    ///
    /// Checks that stat names and stat types line up, then takes a single draw.
    fn assert_settings_smoke<S: Settings>(settings: S) -> Result<()> {
        let target = NormalLogp { dim: 4, mu: 0.1 };
        let math = CpuMath::new(&target);
        let mut rng = StdRng::seed_from_u64(42);

        let names = settings.stat_names(&math);
        let types = settings.stat_types(&math);
        assert!(!names.is_empty());
        assert_eq!(names.len(), types.len());

        let mut chain = settings.new_chain(0, math, &mut rng);
        chain.set_position(&vec![0.2; 4])?;
        let (_position, _info) = chain.draw()?;
        Ok(())
    }

    /// Every settings flavour can build a chain and draw from it.
    #[test]
    fn all_settings_smoke() -> Result<()> {
        assert_settings_smoke(DiagNutsSettings {
            num_tune: 10,
            num_draws: 10,
            ..Default::default()
        })?;
        assert_settings_smoke(LowRankNutsSettings {
            num_tune: 10,
            num_draws: 10,
            ..Default::default()
        })?;
        assert_settings_smoke(DiagMclmcSettings {
            num_tune: 10,
            num_draws: 10,
            ..Default::default()
        })?;
        assert_settings_smoke(LowRankMclmcSettings {
            num_tune: 10,
            num_draws: 10,
            ..Default::default()
        })?;
        Ok(())
    }

    #[test]
    fn sample_chain() -> Result<()> {
        let target = NormalLogp { dim: 10, mu: 0.1 };
        let settings = DiagNutsSettings {
            num_tune: 100,
            num_draws: 100,
            ..Default::default()
        };
        let init = vec![0.2; 10];
        let mut rng = StdRng::seed_from_u64(42);

        // A freshly built chain starts in the tuning phase at draw 0.
        let mut chain = settings.new_chain(0, CpuMath::new(&target), &mut rng);
        let (_position, first) = chain.draw()?;
        assert!(first.tuning);
        assert_eq!(first.draw, 0);

        // Sequential sampling yields one item per requested draw, tagged with the chain id.
        let mut draws = sample_sequentially(CpuMath::new(&target), settings, &init, 200, 1, &mut rng)
            .unwrap()
            .collect_vec();
        assert_eq!(draws.len(), 200);

        let (vals, stats) = draws.remove(100).unwrap();
        assert_eq!(vals.len(), 10);
        assert_eq!(stats.chain, 1);
        assert_eq!(stats.draw, 100);
        Ok(())
    }

    #[cfg(feature = "zarr")]
    #[test]
    fn sample_parallel() -> Result<()> {
        let logp = NormalLogp { dim: 100, mu: 0.1 };
        let settings = DiagNutsSettings {
            num_tune: 100,
            num_draws: 100,
            seed: 10,
            ..Default::default()
        };

        // Repeated pauses followed by a resume are accepted, and aborting
        // afterwards does not report a sampling error.
        let model = CpuModel::new(logp.clone());
        let zarr_config = ZarrConfig::new(Arc::new(MemoryStore::new()));
        let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
        sampler.pause()?;
        sampler.pause()?;
        sampler.resume()?;
        if let (Some(err), _) = sampler.abort()? {
            return Err(err);
        }

        // Aborting while paused works as well.
        let zarr_config = ZarrConfig::new(Arc::new(MemoryStore::new()));
        let model = CpuModel::new(logp.clone());
        let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
        sampler.pause()?;
        if let (Some(err), _) = sampler.abort()? {
            return Err(err);
        }

        // A 100ns wait is not expected to be enough to finish. The sampler must
        // keep answering progress requests and then finish within a second.
        let zarr_config = ZarrConfig::new(Arc::new(MemoryStore::new()));
        let model = CpuModel::new(logp.clone());
        let start = Instant::now();
        let sampler = Sampler::new(model, settings, zarr_config, 4, None)?;

        let mut sampler = match sampler.wait_timeout(Duration::from_nanos(100)) {
            super::SamplerWaitResult::Timeout(running) => running,
            super::SamplerWaitResult::Trace(_) => {
                dbg!(start.elapsed());
                panic!("finished");
            }
            super::SamplerWaitResult::Err(_, _) => panic!("error"),
        };

        for _ in 0..30 {
            sampler.progress()?;
        }

        match sampler.wait_timeout(Duration::from_secs(1)) {
            super::SamplerWaitResult::Trace(_) => {
                dbg!(start.elapsed());
            }
            super::SamplerWaitResult::Timeout(_) => panic!("timeout"),
            super::SamplerWaitResult::Err(err, _) => return Err(err),
        }

        Ok(())
    }

    #[test]
    fn sample_seq() {
        let target = NormalLogp { dim: 10, mu: 0.1 };
        let math = CpuMath::new(&target);
        let settings = DiagNutsSettings {
            num_tune: 100,
            num_draws: 100,
            ..Default::default()
        };
        let init = vec![0.2; 10];
        let mut rng = StdRng::seed_from_u64(42);

        let mut draws = sample_sequentially(math, settings, &init, 200, 1, &mut rng)
            .unwrap()
            .collect_vec();
        assert_eq!(draws.len(), 200);

        let (vals, stats) = draws.remove(100).unwrap();
        assert_eq!(vals.len(), 10);
        assert_eq!(stats.chain, 1);
        assert_eq!(stats.draw, 100);
    }
}