1use anyhow::{Context, Result, bail};
5use itertools::Itertools;
6use nuts_storable::{HasDims, Storable, Value};
7use rand::{Rng, SeedableRng, rngs::ChaCha8Rng};
8use rayon::{ScopeFifo, ThreadPoolBuilder};
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use std::{
11 collections::HashMap,
12 fmt::Debug,
13 ops::Deref,
14 sync::{
15 Arc, Mutex,
16 mpsc::{
17 Receiver, RecvTimeoutError, Sender, SyncSender, TryRecvError, channel, sync_channel,
18 },
19 },
20 thread::{JoinHandle, spawn},
21 time::{Duration, Instant},
22};
23
24use crate::{
25 DiagAdaptExpSettings, Math, StepSizeAdaptMethod,
26 adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy, GlobalStrategyStatsOptions},
27 chain::{AdaptStrategy, Chain, NutsChain, StatOptions},
28 dynamics::{KineticEnergyKind, TransformedHamiltonian, TransformedPointStatsOptions},
29 external_adapt_strategy::{ExternalTransformAdaptation, FlowSettings},
30 mclmc::MclmcTrajectoryKind,
31 model::Model,
32 nuts::NutsOptions,
33 sampler_stats::{SamplerStats, StatsDims},
34 storage::{ChainStorage, StorageConfig, TraceStorage},
35 transform::{
36 DiagAdaptStrategy, DiagMassMatrix, ExternalTransformation, LowRankMassMatrix,
37 LowRankMassMatrixStrategy, LowRankSettings,
38 },
39};
40
41pub trait Settings:
43 private::Sealed + Clone + Copy + Default + Sync + Send + Serialize + DeserializeOwned + 'static
44{
45 type Chain<M: Math>: Chain<M>;
46
47 fn new_chain<M: Math, R: Rng + ?Sized>(
48 &self,
49 chain: u64,
50 math: M,
51 rng: &mut R,
52 ) -> Self::Chain<M>;
53
54 fn hint_num_tune(&self) -> usize;
55 fn hint_num_draws(&self) -> usize;
56 fn num_chains(&self) -> usize;
57 fn seed(&self) -> u64;
58 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions;
59 fn sampler_name(&self) -> &'static str;
60 fn adaptation_name(&self) -> &'static str;
61
62 fn stat_names<M: Math>(&self, math: &M) -> Vec<String> {
63 let dims = StatsDims::from(math);
64 <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::names(&dims)
65 .into_iter()
66 .map(String::from)
67 .collect()
68 }
69
70 fn data_names<M: Math>(&self, math: &M) -> Vec<String> {
71 <M::ExpandedVector as Storable<_>>::names(math)
72 .into_iter()
73 .map(String::from)
74 .collect()
75 }
76
77 fn stat_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
78 self.stat_names(math)
79 .into_iter()
80 .map(|name| (name.clone(), self.stat_type::<M>(math, &name)))
81 .collect()
82 }
83
84 fn stat_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
85 let dims = StatsDims::from(math);
86 <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::item_type(&dims, name)
87 }
88
89 fn data_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
90 self.data_names(math)
91 .into_iter()
92 .map(|name| (name.clone(), self.data_type(math, &name)))
93 .collect()
94 }
95 fn data_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
96 <M::ExpandedVector as Storable<_>>::item_type(math, name)
97 }
98
99 fn stat_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
100 self.stat_names(math)
101 .into_iter()
102 .map(|name| (name.clone(), self.stat_dims::<M>(math, &name)))
103 .collect()
104 }
105
106 fn stat_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
107 let dims = StatsDims::from(math);
108 <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::dims(&dims, name)
109 .into_iter()
110 .map(String::from)
111 .collect()
112 }
113
114 fn stat_dim_sizes<M: Math>(&self, math: &M) -> HashMap<String, u64> {
115 let dims = StatsDims::from(math);
116 dims.dim_sizes()
117 }
118
119 fn data_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
120 self.data_names(math)
121 .into_iter()
122 .map(|name| (name.clone(), self.data_dims(math, &name)))
123 .collect()
124 }
125
126 fn data_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
127 <M::ExpandedVector as Storable<_>>::dims(math, name)
128 .into_iter()
129 .map(String::from)
130 .collect()
131 }
132
133 fn stat_coords<M: Math>(&self, math: &M) -> HashMap<String, Value> {
134 let dims = StatsDims::from(math);
135 dims.coords()
136 }
137
138 fn stat_event_dims<M: Math>(&self, math: &M) -> Vec<(String, Option<String>)> {
139 let dims = StatsDims::from(math);
140 self.stat_names(math)
141 .into_iter()
142 .map(|name| {
143 let event_dim =
144 <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::event_dim(
145 &dims, &name,
146 )
147 .map(String::from);
148 (name, event_dim)
149 })
150 .collect()
151 }
152}
153
/// Per-draw summary returned by a chain together with each draw.
///
/// `ChainProgress::update` folds these values into the running progress of a chain.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Progress {
    /// NOTE(review): presumably the index of the draw within the chain — confirm
    /// whether tuning draws are counted.
    pub draw: u64,
    /// NOTE(review): presumably the index of the chain that produced the draw.
    pub chain: u64,
    /// Divergence flag of this draw; `ChainProgress` counts it only outside tuning.
    pub diverging: bool,
    /// Whether this draw was produced during tuning.
    pub tuning: bool,
    /// Step size reported for this draw.
    pub step_size: f64,
    /// Number of steps reported for this draw.
    pub num_steps: u64,
}
164
165mod private {
166 use super::{
167 DiagMclmcSettings, DiagNutsSettings, FlowMclmcSettings, FlowNutsSettings,
168 LowRankMclmcSettings, LowRankNutsSettings,
169 };
170
171 pub trait Sealed {}
172
173 impl Sealed for DiagNutsSettings {}
174
175 impl Sealed for LowRankNutsSettings {}
176
177 impl Sealed for FlowNutsSettings {}
178
179 impl Sealed for DiagMclmcSettings {}
180
181 impl Sealed for LowRankMclmcSettings {}
182
183 impl Sealed for FlowMclmcSettings {}
184}
185
/// Settings of the No-U-Turn sampler, generic over the adaptation options `A`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct NutsSettings<A: Debug + Copy + Default + Serialize> {
    /// Number of tuning draws per chain; also passed to the adaptation strategy.
    pub num_tune: u64,
    /// Number of draws per chain after tuning.
    pub num_draws: u64,
    /// Forwarded to `NutsOptions::maxdepth` (maximum tree depth).
    pub maxdepth: u64,
    /// Forwarded to `NutsOptions::mindepth` (minimum tree depth).
    pub mindepth: u64,
    /// Forwarded to `TransformedPointStatsOptions::store_gradient`.
    pub store_gradient: bool,
    /// Forwarded to `TransformedPointStatsOptions::store_unconstrained`.
    pub store_unconstrained: bool,
    /// Forwarded to `TransformedPointStatsOptions::store_transformed`.
    pub store_transformed: bool,
    /// Forwarded to `NutsOptions::max_energy_error`.
    pub max_energy_error: f64,
    /// Forwarded to both `NutsOptions` and `DivergenceStatsOptions`.
    pub store_divergences: bool,
    /// Options of the adaptation strategy.
    pub adapt_options: A,
    /// Forwarded to `NutsOptions::check_turning`.
    pub check_turning: bool,
    /// Forwarded to `NutsOptions::target_integration_time`.
    pub target_integration_time: Option<f64>,
    /// Kinetic energy passed to `TransformedHamiltonian::new`.
    pub trajectory_kind: KineticEnergyKind,
    /// Number of chains to run.
    pub num_chains: usize,
    /// Base seed for the random number generators of all chains.
    pub seed: u64,
    /// Forwarded to `NutsOptions::extra_doublings`.
    pub extra_doublings: u64,
}
229
/// NUTS with diagonal mass-matrix adaptation.
pub type DiagNutsSettings = NutsSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;
/// Former name of [`DiagNutsSettings`].
#[deprecated(since = "0.0.0", note = "Use DiagNutsSettings instead")]
pub type DiagGradNutsSettings = DiagNutsSettings;
/// NUTS with low-rank mass-matrix adaptation.
pub type LowRankNutsSettings = NutsSettings<EuclideanAdaptOptions<LowRankSettings>>;
/// NUTS with an external (flow) transformation adapted by `ExternalTransformAdaptation`.
pub type FlowNutsSettings = NutsSettings<FlowSettings>;
/// Former name of [`FlowNutsSettings`].
#[deprecated(since = "0.0.0", note = "Use FlowNutsSettings instead")]
pub type TransformedNutsSettings = FlowNutsSettings;
239
/// Settings of the MCLMC sampler, generic over the adaptation options `A`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct MclmcSettings<A: Debug + Copy + Default + Serialize> {
    /// Step size. The diagonal and low-rank variants install it as
    /// `StepSizeAdaptMethod::Fixed(step_size)` in their adaptation options.
    /// NOTE(review): `FlowMclmcSettings::new_chain` never reads this field —
    /// confirm that the flow variant takes its step size from `FlowSettings`.
    pub step_size: f64,
    /// Passed to `TransformedHamiltonian::set_momentum_decoherence_length`.
    pub momentum_decoherence_length: f64,
    /// Number of tuning draws per chain.
    pub num_tune: u64,
    /// Number of draws per chain after tuning.
    pub num_draws: u64,
    /// Number of chains to run.
    pub num_chains: usize,
    /// Base seed for the random number generators of all chains.
    pub seed: u64,
    /// Passed to `MclmcChain::new`.
    pub max_energy_error: f64,
    /// Forwarded to `TransformedPointStatsOptions::store_unconstrained`.
    pub store_unconstrained: bool,
    /// Forwarded to `TransformedPointStatsOptions::store_gradient`.
    pub store_gradient: bool,
    /// Forwarded to `TransformedPointStatsOptions::store_transformed`.
    pub store_transformed: bool,
    /// Forwarded to `DivergenceStatsOptions::store_divergences`.
    pub store_divergences: bool,
    /// Options of the adaptation strategy.
    pub adapt_options: A,
    /// Passed to `MclmcChain::new`.
    pub subsample_frequency: f64,
    /// Passed to `MclmcChain::new`.
    pub dynamic_step_size: bool,
    /// Trajectory type. `Microcanonical` starts with microcanonical kinetic
    /// energy; `Euclidean` and `EuclideanEarlyThenMicrocanonical` start Euclidean.
    pub trajectory_kind: MclmcTrajectoryKind,
    /// Fraction of `num_tune`; the product (truncated to an integer) is passed
    /// to `MclmcChain::new` as the switch draw. NOTE(review): presumably only
    /// relevant for `EuclideanEarlyThenMicrocanonical`.
    pub trajectory_switch_fraction: f64,
}
308
/// MCLMC with diagonal mass-matrix adaptation.
pub type DiagMclmcSettings = MclmcSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;
/// MCLMC with low-rank mass-matrix adaptation.
pub type LowRankMclmcSettings = MclmcSettings<EuclideanAdaptOptions<LowRankSettings>>;
/// MCLMC with an external (flow) transformation adapted by `ExternalTransformAdaptation`.
pub type FlowMclmcSettings = MclmcSettings<FlowSettings>;
/// Former name of [`FlowMclmcSettings`].
#[deprecated(since = "0.0.0", note = "Use FlowMclmcSettings instead")]
pub type TransformedMclmcSettings = FlowMclmcSettings;
327
/// Convert the `u64` setting `value` into a `usize` draw-count hint.
///
/// # Panics
/// Panics, naming `field`, if `value` does not fit into `usize`. This can only
/// happen on targets where `usize` is narrower than 64 bits.
fn usize_hint(value: u64, field: &str) -> usize {
    match usize::try_from(value) {
        Ok(hint) => hint,
        Err(_) => panic!("{field} must be smaller than usize::MAX"),
    }
}
333
334fn default_mclmc_settings<A: Debug + Copy + Default + Serialize>(
335 adapt_options: A,
336 num_tune: u64,
337 num_chains: usize,
338 max_energy_error: f64,
339) -> MclmcSettings<A> {
340 MclmcSettings {
341 step_size: 0.5,
342 momentum_decoherence_length: 3.0,
343 num_tune,
344 num_draws: 1000,
345 num_chains,
346 seed: 0,
347 max_energy_error,
348 store_unconstrained: false,
349 store_gradient: false,
350 store_divergences: false,
351 store_transformed: false,
352 adapt_options,
353 subsample_frequency: 1.0,
354 dynamic_step_size: true,
355 trajectory_kind: MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical,
356 trajectory_switch_fraction: 0.3,
357 }
358}
359
360impl Default for DiagMclmcSettings {
361 fn default() -> Self {
362 let mut adapt_options = EuclideanAdaptOptions::default();
363 adapt_options.step_size_settings.adapt_options.method = StepSizeAdaptMethod::Fixed(0.5);
364 default_mclmc_settings(adapt_options, 400, 6, 1000.0)
365 }
366}
367
368impl Default for LowRankMclmcSettings {
369 fn default() -> Self {
370 let mut adapt_options = EuclideanAdaptOptions::default();
371 adapt_options.early_mass_matrix_switch_freq = 20;
372 adapt_options.step_size_settings.adapt_options.method = StepSizeAdaptMethod::Fixed(0.5);
373 default_mclmc_settings(adapt_options, 800, 6, 1000.0)
374 }
375}
376
377impl Default for FlowMclmcSettings {
378 fn default() -> Self {
379 default_mclmc_settings(FlowSettings::default(), 1500, 1, 20.0)
380 }
381}
382
/// Chain type of [`DiagMclmcSettings`]: MCLMC with diagonal mass-matrix adaptation.
type DiagMclmcChain<M> = crate::mclmc::MclmcChain<
    M,
    ChaCha8Rng,
    GlobalStrategy<M, DiagAdaptStrategy<M>>,
    DiagMassMatrix<M>,
>;
/// Chain type of [`LowRankMclmcSettings`]: MCLMC with low-rank mass-matrix adaptation.
type LowRankMclmcChain<M> = crate::mclmc::MclmcChain<
    M,
    ChaCha8Rng,
    GlobalStrategy<M, LowRankMassMatrixStrategy>,
    LowRankMassMatrix<M>,
>;
395
396impl Settings for DiagMclmcSettings {
397 type Chain<M: Math> = DiagMclmcChain<M>;
398
399 fn new_chain<M: Math, R: Rng + ?Sized>(
400 &self,
401 chain: u64,
402 mut math: M,
403 rng: &mut R,
404 ) -> Self::Chain<M> {
405 use crate::dynamics::KineticEnergyKind;
406 use crate::mclmc::MclmcChain;
407 use crate::stepsize::StepSizeAdaptMethod;
408
409 let num_tune = self.num_tune;
410 let mut adapt_options = self.adapt_options;
411 adapt_options.step_size_settings.adapt_options.method =
412 StepSizeAdaptMethod::Fixed(self.step_size);
413 let strategy = GlobalStrategy::<M, DiagAdaptStrategy<M>>::new(
414 &mut math,
415 adapt_options,
416 num_tune,
417 chain,
418 );
419 let mass_matrix = DiagMassMatrix::new(
420 &mut math,
421 self.adapt_options.mass_matrix_options.store_mass_matrix,
422 );
423 let initial_kind = match self.trajectory_kind {
424 MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
425 MclmcTrajectoryKind::Euclidean
426 | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
427 };
428 let mut hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, initial_kind);
429 hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
430 let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
431 let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
432 let stats_options = self.stats_options::<M>();
433 MclmcChain::new(
434 math,
435 hamiltonian,
436 strategy,
437 rng,
438 chain,
439 self.subsample_frequency,
440 self.dynamic_step_size,
441 self.trajectory_kind,
442 switch_draw,
443 self.max_energy_error,
444 stats_options,
445 )
446 }
447
448 fn hint_num_tune(&self) -> usize {
449 usize_hint(self.num_tune, "num_tune")
450 }
451
452 fn hint_num_draws(&self) -> usize {
453 usize_hint(self.num_draws, "num_draws")
454 }
455
456 fn num_chains(&self) -> usize {
457 self.num_chains
458 }
459
460 fn seed(&self) -> u64 {
461 self.seed
462 }
463
464 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
465 StatOptions {
466 adapt: GlobalStrategyStatsOptions {
467 step_size: (),
468 mass_matrix: (),
469 },
470 hamiltonian: -1,
471 point: {
472 let store_gradient = self.store_gradient;
473 let store_unconstrained = self.store_unconstrained;
474 let store_transformed = self.store_transformed;
475 TransformedPointStatsOptions {
476 store_gradient,
477 store_unconstrained,
478 store_transformed,
479 }
480 },
481 divergence: crate::dynamics::DivergenceStatsOptions {
482 store_divergences: self.store_divergences,
483 },
484 }
485 }
486
487 fn sampler_name(&self) -> &'static str {
488 "mclmc"
489 }
490
491 fn adaptation_name(&self) -> &'static str {
492 "diagonal"
493 }
494}
495
496fn default_nuts_settings<A: Debug + Copy + Default + Serialize>(
497 adapt_options: A,
498 num_tune: u64,
499 num_chains: usize,
500 max_energy_error: f64,
501) -> NutsSettings<A> {
502 NutsSettings {
503 num_tune,
504 num_draws: 1000,
505 maxdepth: 10,
506 mindepth: 0,
507 max_energy_error,
508 store_gradient: false,
509 store_unconstrained: false,
510 store_transformed: false,
511 store_divergences: false,
512 adapt_options,
513 check_turning: true,
514 seed: 0,
515 num_chains,
516 target_integration_time: None,
517 trajectory_kind: KineticEnergyKind::Euclidean,
518 extra_doublings: 0,
519 }
520}
521
522impl Settings for LowRankMclmcSettings {
523 type Chain<M: Math> = LowRankMclmcChain<M>;
524
525 fn new_chain<M: Math, R: Rng + ?Sized>(
526 &self,
527 chain: u64,
528 mut math: M,
529 rng: &mut R,
530 ) -> Self::Chain<M> {
531 use crate::dynamics::KineticEnergyKind;
532 use crate::mclmc::MclmcChain;
533 use crate::stepsize::StepSizeAdaptMethod;
534
535 let num_tune = self.num_tune;
536 let mut adapt_options = self.adapt_options;
537 adapt_options.step_size_settings.adapt_options.method =
538 StepSizeAdaptMethod::Fixed(self.step_size);
539 let strategy = GlobalStrategy::<M, LowRankMassMatrixStrategy>::new(
540 &mut math,
541 adapt_options,
542 num_tune,
543 chain,
544 );
545 let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
546 let initial_kind = match self.trajectory_kind {
547 MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
548 MclmcTrajectoryKind::Euclidean
549 | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
550 };
551 let mut hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, initial_kind);
552 hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
553 let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
554 let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
555 let stats_options = self.stats_options::<M>();
556 MclmcChain::new(
557 math,
558 hamiltonian,
559 strategy,
560 rng,
561 chain,
562 self.subsample_frequency,
563 self.dynamic_step_size,
564 self.trajectory_kind,
565 switch_draw,
566 self.max_energy_error,
567 stats_options,
568 )
569 }
570
571 fn hint_num_tune(&self) -> usize {
572 usize_hint(self.num_tune, "num_tune")
573 }
574
575 fn hint_num_draws(&self) -> usize {
576 usize_hint(self.num_draws, "num_draws")
577 }
578
579 fn num_chains(&self) -> usize {
580 self.num_chains
581 }
582
583 fn seed(&self) -> u64 {
584 self.seed
585 }
586
587 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
588 StatOptions {
589 adapt: GlobalStrategyStatsOptions {
590 step_size: (),
591 mass_matrix: (),
592 },
593 hamiltonian: -1,
594 point: {
595 let store_gradient = self.store_gradient;
596 let store_unconstrained = self.store_unconstrained;
597 let store_transformed = self.store_transformed;
598 TransformedPointStatsOptions {
599 store_gradient,
600 store_unconstrained,
601 store_transformed,
602 }
603 },
604 divergence: crate::dynamics::DivergenceStatsOptions {
605 store_divergences: self.store_divergences,
606 },
607 }
608 }
609
610 fn sampler_name(&self) -> &'static str {
611 "mclmc"
612 }
613
614 fn adaptation_name(&self) -> &'static str {
615 "low_rank"
616 }
617}
618
619impl Default for DiagNutsSettings {
620 fn default() -> Self {
621 default_nuts_settings(EuclideanAdaptOptions::default(), 400, 6, 1000.0)
622 }
623}
624
625impl Default for LowRankNutsSettings {
626 fn default() -> Self {
627 let mut vals = default_nuts_settings(EuclideanAdaptOptions::default(), 800, 6, 1000.0);
628 vals.adapt_options.mass_matrix_update_freq = 20;
629 vals
630 }
631}
632
633impl Default for FlowNutsSettings {
634 fn default() -> Self {
635 default_nuts_settings(FlowSettings::default(), 1500, 1, 20.0)
636 }
637}
638
/// Chain type of [`DiagNutsSettings`]: NUTS with diagonal mass-matrix adaptation.
type DiagNutsChain<M> = NutsChain<M, ChaCha8Rng, GlobalStrategy<M, DiagAdaptStrategy<M>>>;
/// Chain type of [`LowRankNutsSettings`]: NUTS with low-rank mass-matrix adaptation.
type LowRankNutsChain<M> = NutsChain<M, ChaCha8Rng, GlobalStrategy<M, LowRankMassMatrixStrategy>>;
641
642fn nuts_options(settings: &NutsSettings<impl Debug + Copy + Default + Serialize>) -> NutsOptions {
643 NutsOptions {
644 maxdepth: settings.maxdepth,
645 mindepth: settings.mindepth,
646 store_divergences: settings.store_divergences,
647 check_turning: settings.check_turning,
648 target_integration_time: settings.target_integration_time,
649 extra_doublings: settings.extra_doublings,
650 max_energy_error: settings.max_energy_error,
651 }
652}
653
654impl Settings for LowRankNutsSettings {
655 type Chain<M: Math> = LowRankNutsChain<M>;
656
657 fn new_chain<M: Math, R: Rng + ?Sized>(
658 &self,
659 chain: u64,
660 mut math: M,
661 mut rng: &mut R,
662 ) -> Self::Chain<M> {
663 let num_tune = self.num_tune;
664 let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
665 let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
666 let hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, self.trajectory_kind);
667
668 let options = nuts_options(self);
669
670 let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
671
672 NutsChain::new(
673 math,
674 hamiltonian,
675 strategy,
676 options,
677 rng,
678 chain,
679 self.stats_options(),
680 )
681 }
682
683 fn hint_num_tune(&self) -> usize {
684 usize_hint(self.num_tune, "num_tune")
685 }
686
687 fn hint_num_draws(&self) -> usize {
688 usize_hint(self.num_draws, "num_draws")
689 }
690
691 fn num_chains(&self) -> usize {
692 self.num_chains
693 }
694
695 fn seed(&self) -> u64 {
696 self.seed
697 }
698
699 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
700 StatOptions {
701 adapt: GlobalStrategyStatsOptions {
702 mass_matrix: (),
703 step_size: (),
704 },
705 hamiltonian: -1,
706 point: {
707 let store_gradient = self.store_gradient;
708 let store_unconstrained = self.store_unconstrained;
709 let store_transformed = self.store_transformed;
710 TransformedPointStatsOptions {
711 store_gradient,
712 store_unconstrained,
713 store_transformed,
714 }
715 },
716 divergence: crate::dynamics::DivergenceStatsOptions {
717 store_divergences: self.store_divergences,
718 },
719 }
720 }
721
722 fn sampler_name(&self) -> &'static str {
723 "nuts"
724 }
725
726 fn adaptation_name(&self) -> &'static str {
727 "low_rank"
728 }
729}
730
731impl Settings for DiagNutsSettings {
732 type Chain<M: Math> = DiagNutsChain<M>;
733
734 fn new_chain<M: Math, R: Rng + ?Sized>(
735 &self,
736 chain: u64,
737 mut math: M,
738 mut rng: &mut R,
739 ) -> Self::Chain<M> {
740 let num_tune = self.num_tune;
741 let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
742 let mass_matrix = DiagMassMatrix::new(
743 &mut math,
744 self.adapt_options.mass_matrix_options.store_mass_matrix,
745 );
746 let potential = TransformedHamiltonian::new(&mut math, mass_matrix, self.trajectory_kind);
747
748 let options = nuts_options(self);
749
750 let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
751
752 NutsChain::new(
753 math,
754 potential,
755 strategy,
756 options,
757 rng,
758 chain,
759 self.stats_options(),
760 )
761 }
762
763 fn hint_num_tune(&self) -> usize {
764 usize_hint(self.num_tune, "num_tune")
765 }
766
767 fn hint_num_draws(&self) -> usize {
768 usize_hint(self.num_draws, "num_draws")
769 }
770
771 fn num_chains(&self) -> usize {
772 self.num_chains
773 }
774
775 fn seed(&self) -> u64 {
776 self.seed
777 }
778
779 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
780 StatOptions {
781 adapt: GlobalStrategyStatsOptions {
782 mass_matrix: (),
783 step_size: (),
784 },
785 hamiltonian: -1,
786 point: {
787 let store_gradient = self.store_gradient;
788 let store_unconstrained = self.store_unconstrained;
789 let store_transformed = self.store_transformed;
790 TransformedPointStatsOptions {
791 store_gradient,
792 store_unconstrained,
793 store_transformed,
794 }
795 },
796 divergence: crate::dynamics::DivergenceStatsOptions {
797 store_divergences: self.store_divergences,
798 },
799 }
800 }
801
802 fn sampler_name(&self) -> &'static str {
803 "nuts"
804 }
805
806 fn adaptation_name(&self) -> &'static str {
807 "diagonal"
808 }
809}
810
811impl Settings for FlowNutsSettings {
812 type Chain<M: Math> = NutsChain<M, ChaCha8Rng, ExternalTransformAdaptation>;
813
814 fn new_chain<M: Math, R: Rng + ?Sized>(
815 &self,
816 chain: u64,
817 mut math: M,
818 mut rng: &mut R,
819 ) -> Self::Chain<M> {
820 let num_tune = self.num_tune;
821
822 let strategy =
823 ExternalTransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
824 let params = math
825 .new_transformation(rng, math.dim(), chain)
826 .expect("Failed to create external transformation");
827 let transform = ExternalTransformation::new(params);
828 let hamiltonian = TransformedHamiltonian::new(&mut math, transform, self.trajectory_kind);
829
830 let options = nuts_options(self);
831
832 let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
833 NutsChain::new(
834 math,
835 hamiltonian,
836 strategy,
837 options,
838 rng,
839 chain,
840 self.stats_options(),
841 )
842 }
843
844 fn hint_num_tune(&self) -> usize {
845 usize_hint(self.num_tune, "num_tune")
846 }
847
848 fn hint_num_draws(&self) -> usize {
849 usize_hint(self.num_draws, "num_draws")
850 }
851
852 fn num_chains(&self) -> usize {
853 self.num_chains
854 }
855
856 fn seed(&self) -> u64 {
857 self.seed
858 }
859
860 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
861 StatOptions {
862 adapt: (),
863 hamiltonian: (),
864 point: {
865 let store_gradient = self.store_gradient;
866 let store_unconstrained = self.store_unconstrained;
867 let store_transformed = self.store_transformed;
868 TransformedPointStatsOptions {
869 store_gradient,
870 store_unconstrained,
871 store_transformed,
872 }
873 },
874 divergence: crate::dynamics::DivergenceStatsOptions {
875 store_divergences: self.store_divergences,
876 },
877 }
878 }
879
880 fn sampler_name(&self) -> &'static str {
881 "nuts"
882 }
883
884 fn adaptation_name(&self) -> &'static str {
885 "flow"
886 }
887}
888
889impl Settings for FlowMclmcSettings {
890 type Chain<M: Math> = crate::mclmc::MclmcChain<
891 M,
892 ChaCha8Rng,
893 ExternalTransformAdaptation,
894 ExternalTransformation<M>,
895 >;
896
897 fn new_chain<M: Math, R: Rng + ?Sized>(
898 &self,
899 chain: u64,
900 mut math: M,
901 rng: &mut R,
902 ) -> Self::Chain<M> {
903 use crate::dynamics::KineticEnergyKind;
904 use crate::mclmc::MclmcChain;
905
906 let num_tune = self.num_tune;
907 let strategy =
908 ExternalTransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
909 let params = math
910 .new_transformation(rng, math.dim(), chain)
911 .expect("Failed to create external transformation");
912 let transform = ExternalTransformation::new(params);
913 let initial_kind = match self.trajectory_kind {
914 MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
915 MclmcTrajectoryKind::Euclidean
916 | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
917 };
918 let mut hamiltonian = TransformedHamiltonian::new(&mut math, transform, initial_kind);
919 hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
920 let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
921 let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
922 let stats_options = self.stats_options::<M>();
923 MclmcChain::new(
924 math,
925 hamiltonian,
926 strategy,
927 rng,
928 chain,
929 self.subsample_frequency,
930 self.dynamic_step_size,
931 self.trajectory_kind,
932 switch_draw,
933 self.max_energy_error,
934 stats_options,
935 )
936 }
937
938 fn hint_num_tune(&self) -> usize {
939 usize_hint(self.num_tune, "num_tune")
940 }
941
942 fn hint_num_draws(&self) -> usize {
943 usize_hint(self.num_draws, "num_draws")
944 }
945
946 fn num_chains(&self) -> usize {
947 self.num_chains
948 }
949
950 fn seed(&self) -> u64 {
951 self.seed
952 }
953
954 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
955 StatOptions {
956 adapt: (),
957 hamiltonian: (),
958 point: {
959 let store_gradient = self.store_gradient;
960 let store_unconstrained = self.store_unconstrained;
961 let store_transformed = self.store_transformed;
962 TransformedPointStatsOptions {
963 store_gradient,
964 store_unconstrained,
965 store_transformed,
966 }
967 },
968 divergence: crate::dynamics::DivergenceStatsOptions {
969 store_divergences: self.store_divergences,
970 },
971 }
972 }
973
974 fn sampler_name(&self) -> &'static str {
975 "mclmc"
976 }
977
978 fn adaptation_name(&self) -> &'static str {
979 "flow"
980 }
981}
982
983pub fn sample_sequentially<'math, M: Math + 'math, R: Rng + ?Sized>(
984 math: M,
985 settings: DiagNutsSettings,
986 start: &[f64],
987 draws: u64,
988 chain: u64,
989 rng: &mut R,
990) -> Result<impl Iterator<Item = Result<(Box<[f64]>, Progress)>> + 'math> {
991 let mut sampler = settings.new_chain(chain, math, rng);
992 sampler.set_position(start)?;
993 Ok((0..draws).map(move |_| sampler.draw()))
994}
995
/// Progress of a single chain, as handed to a `ProgressCallback`.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct ChainProgress {
    /// Number of draws completed so far, tuning draws included.
    pub finished_draws: usize,
    /// Total number of draws the chain is expected to produce (tuning + draws).
    pub total_draws: usize,
    /// Number of divergent draws outside tuning.
    pub divergences: usize,
    /// Tuning flag of the most recent draw (`true` before the first draw).
    pub tuning: bool,
    /// Set once the worker has created its sampler, before initialisation.
    pub started: bool,
    /// Step count of the most recent draw.
    pub latest_num_steps: usize,
    /// Step counts summed over all draws.
    pub total_num_steps: usize,
    /// Step size of the most recent draw.
    pub step_size: f64,
    /// Measured draw durations summed over all draws.
    pub runtime: Duration,
    /// Indices (counted from the first draw, tuning included) of the divergent
    /// draws outside tuning.
    pub divergent_draws: Vec<usize>,
}
1010
1011impl ChainProgress {
1012 fn new(total: usize) -> Self {
1013 Self {
1014 finished_draws: 0,
1015 total_draws: total,
1016 divergences: 0,
1017 tuning: true,
1018 started: false,
1019 latest_num_steps: 0,
1020 step_size: 0f64,
1021 total_num_steps: 0,
1022 runtime: Duration::ZERO,
1023 divergent_draws: Vec::new(),
1024 }
1025 }
1026
1027 fn update(&mut self, stats: &Progress, draw_duration: Duration) {
1028 if stats.diverging & !stats.tuning {
1029 self.divergences += 1;
1030 self.divergent_draws.push(self.finished_draws);
1031 }
1032 self.finished_draws += 1;
1033 self.tuning = stats.tuning;
1034
1035 self.latest_num_steps = stats.num_steps as usize;
1036 self.total_num_steps += stats.num_steps as usize;
1037 self.step_size = stats.step_size;
1038 self.runtime += draw_duration;
1039 }
1040}
1041
/// Command sent from a `ChainProcess` handle to its sampling worker.
enum ChainCommand {
    /// Continue sampling after a pause.
    Resume,
    /// Make the worker block until it receives the next command.
    Pause,
}
1046
/// Handle to one chain running on the worker thread pool.
///
/// The worker owns the sampler. This handle shares the trace and the progress
/// record with the worker and can send it `ChainCommand`s.
struct ChainProcess<T>
where
    T: TraceStorage,
{
    /// Command channel to the worker. The worker stops sampling when this
    /// sender is dropped (it sees `Disconnected`).
    stop_marker: Sender<ChainCommand>,
    /// The chain's trace. It is `None` once taken for finalization, which also
    /// makes the worker stop at its next draw.
    trace: Arc<Mutex<Option<T::ChainStorage>>>,
    /// Progress record updated by the worker after each draw.
    progress: Arc<Mutex<ChainProgress>>,
}
1055
1056impl<T: TraceStorage> ChainProcess<T> {
1057 fn finalize_many(trace: T, chains: Vec<Self>) -> Result<(Option<anyhow::Error>, T::Finalized)> {
1058 let finalized_chain_traces = chains
1059 .into_iter()
1060 .filter_map(|chain| chain.trace.lock().expect("Poisoned lock").take())
1061 .map(|chain| chain.finalize())
1062 .collect_vec();
1063 trace.finalize(finalized_chain_traces)
1064 }
1065
1066 fn progress(&self) -> ChainProgress {
1067 self.progress.lock().expect("Poisoned lock").clone()
1068 }
1069
1070 fn resume(&self) -> Result<()> {
1071 self.stop_marker.send(ChainCommand::Resume)?;
1072 Ok(())
1073 }
1074
1075 fn pause(&self) -> Result<()> {
1076 self.stop_marker.send(ChainCommand::Pause)?;
1077 Ok(())
1078 }
1079
1080 fn start<'model, M: Model, S: Settings>(
1081 model: &'model M,
1082 chain_trace: T::ChainStorage,
1083 chain_id: u64,
1084 seed: u64,
1085 settings: &'model S,
1086 scope: &ScopeFifo<'model>,
1087 results: Sender<Result<()>>,
1088 ) -> Result<Self> {
1089 let (stop_marker_tx, stop_marker_rx) = channel();
1090
1091 let mut rng = ChaCha8Rng::seed_from_u64(seed);
1092 rng.set_stream(chain_id + 1);
1093
1094 let chain_trace = Arc::new(Mutex::new(Some(chain_trace)));
1095 let progress = Arc::new(Mutex::new(ChainProgress::new(
1096 settings.hint_num_draws() + settings.hint_num_tune(),
1097 )));
1098
1099 let trace_inner = chain_trace.clone();
1100 let progress_inner = progress.clone();
1101
1102 scope.spawn_fifo(move |_| {
1103 let chain_trace = trace_inner;
1104 let progress = progress_inner;
1105
1106 let mut sample = move || {
1107 let logp = model
1108 .math(&mut rng)
1109 .context("Failed to create model density")?;
1110 let dim = logp.dim();
1111
1112 let mut sampler = settings.new_chain(chain_id, logp, &mut rng);
1113
1114 progress.lock().expect("Poisoned mutex").started = true;
1115
1116 let mut initval = vec![0f64; dim];
1117 let mut error = None;
1119 for _ in 0..500 {
1120 model
1121 .init_position(&mut rng, &mut initval)
1122 .context("Failed to generate a new initial position")?;
1123 if let Err(err) = sampler.set_position(&initval) {
1124 error = Some(err);
1125 continue;
1126 }
1127 error = None;
1128 break;
1129 }
1130
1131 if let Some(error) = error {
1132 return Err(error.context("All initialization points failed"));
1133 }
1134
1135 let draws = settings.hint_num_tune() + settings.hint_num_draws();
1136
1137 let mut msg = stop_marker_rx.try_recv();
1138 let mut draw = 0;
1139 loop {
1140 match msg {
1141 Err(TryRecvError::Disconnected) => {
1143 break;
1144 }
1145 Err(TryRecvError::Empty) => {}
1146 Ok(ChainCommand::Pause) => {
1147 msg = stop_marker_rx.recv().map_err(|e| e.into());
1148 continue;
1149 }
1150 Ok(ChainCommand::Resume) => {}
1151 }
1152
1153 let now = Instant::now();
1154 let (_point, mut draw_data, mut stats, info) = sampler.expanded_draw().unwrap();
1155
1156 let mut guard = chain_trace
1157 .lock()
1158 .expect("Could not unlock trace lock. Poisoned mutex");
1159
1160 let Some(trace_val) = guard.as_mut() else {
1161 break;
1163 };
1164 progress
1165 .lock()
1166 .expect("Poisoned mutex")
1167 .update(&info, now.elapsed());
1168
1169 let math = sampler.math();
1170 let dims = StatsDims::from(math.deref());
1171 trace_val.record_sample(
1172 settings,
1173 stats.get_all(&dims),
1174 draw_data.get_all(math.deref()),
1175 &info,
1176 )?;
1177
1178 draw += 1;
1179 if draw == draws {
1180 break;
1181 }
1182
1183 msg = stop_marker_rx.try_recv();
1184 }
1185 Ok(())
1186 };
1187
1188 let result = sample();
1189
1190 let _ = results.send(result);
1193 drop(results);
1194 });
1195
1196 Ok(Self {
1197 trace: chain_trace,
1198 stop_marker: stop_marker_tx,
1199 progress,
1200 })
1201 }
1202
1203 fn flush(&self) -> Result<()> {
1204 self.trace
1205 .lock()
1206 .map_err(|_| anyhow::anyhow!("Could not lock trace mutex"))
1207 .context("Could not flush trace")?
1208 .as_mut()
1209 .map(|v| v.flush())
1210 .transpose()?;
1211 Ok(())
1212 }
1213}
1214
/// Request sent to the sampler's main thread.
///
/// `Sampler::new` creates the command channel with `sync_channel(0)`, so each
/// send blocks until the main thread receives it. NOTE(review): the handler is
/// outside this view; the per-variant notes below are inferred from the names.
#[derive(Debug)]
enum SamplerCommand {
    /// Presumably pauses all chains (via `ChainProcess::pause`).
    Pause,
    /// Presumably resumes paused chains (via `ChainProcess::resume`).
    Continue,
    /// Presumably answered with `SamplerResponse::Progress`.
    Progress,
    /// Presumably flushes all chain traces (via `ChainProcess::flush`).
    Flush,
    /// Presumably answered with `SamplerResponse::Inspect`.
    Inspect,
}
1223
/// Reply from the sampler's main thread to a `SamplerCommand`.
///
/// `Sampler` instantiates `T` as `(Option<anyhow::Error>, F)`.
enum SamplerResponse<T: Send + 'static> {
    /// Acknowledgement without payload. NOTE(review): presumably sent for
    /// commands that need no data — confirm in the main loop.
    Ok(),
    /// Progress of every chain.
    Progress(Box<[ChainProgress]>),
    /// Payload of an inspection request.
    Inspect(T),
}
1229
/// Outcome of waiting for a [`Sampler`].
///
/// NOTE(review): the wait method is outside this view; the variant notes below
/// are inferred from the payload types.
pub enum SamplerWaitResult<F: Send + 'static> {
    /// Sampling finished and produced the finalized trace.
    Trace(F),
    /// The wait timed out; the still-running sampler is handed back.
    Timeout(Sampler<F>),
    /// Sampling failed; a finalized trace may still be available.
    Err(anyhow::Error, Option<F>),
}
1235
/// Handle to a multi-chain sampling run started by `Sampler::new`.
///
/// Sampling is driven by a dedicated main thread, which runs the chains on a
/// rayon thread pool.
pub struct Sampler<F: Send + 'static> {
    /// Main thread spawned by `Sampler::new`. Its result is an optional error
    /// plus the finalized trace `F`.
    main_thread: JoinHandle<Result<(Option<anyhow::Error>, F)>>,
    /// Commands to the main thread (rendezvous channel with capacity 0).
    commands: SyncSender<SamplerCommand>,
    /// Responses from the main thread (rendezvous channel with capacity 0).
    responses: Receiver<SamplerResponse<(Option<anyhow::Error>, F)>>,
    /// One `Result` per chain worker, sent when that worker finishes.
    results: Receiver<Result<()>>,
}
1242
/// User callback that receives progress snapshots of all chains.
pub struct ProgressCallback {
    /// Called with a `Duration` and one `ChainProgress` per chain.
    /// NOTE(review): the `Duration` is presumably the elapsed sampling time
    /// excluding pauses — confirm in `Sampler::new`'s main loop.
    pub callback: Box<dyn FnMut(Duration, Box<[ChainProgress]>) + Send>,
    /// NOTE(review): presumably the interval between callback invocations.
    pub rate: Duration,
}
1247
1248impl<F: Send + 'static> Sampler<F> {
1249 pub fn new<M, S, C, T>(
1250 model: M,
1251 settings: S,
1252 trace_config: C,
1253 num_cores: usize,
1254 callback: Option<ProgressCallback>,
1255 ) -> Result<Self>
1256 where
1257 S: Settings,
1258 C: StorageConfig<Storage = T>,
1259 M: Model,
1260 T: TraceStorage<Finalized = F>,
1261 {
1262 let (commands_tx, commands_rx) = sync_channel(0);
1263 let (responses_tx, responses_rx) = sync_channel(0);
1264 let (results_tx, results_rx) = channel();
1265
1266 let main_thread = spawn(move || {
1267 let pool = ThreadPoolBuilder::new()
1268 .num_threads(num_cores + 1) .thread_name(|i| format!("nutpie-worker-{i}"))
1270 .build()
1271 .context("Could not start thread pool")?;
1272
1273 let settings_ref = &settings;
1274 let model_ref = &model;
1275 let mut callback = callback;
1276
1277 pool.scope_fifo(move |scope| {
1278 let results = results_tx;
1279 let mut chains = Vec::with_capacity(settings.num_chains());
1280
1281 let mut rng = ChaCha8Rng::seed_from_u64(settings.seed());
1282 rng.set_stream(0);
1283
1284 let math = model_ref
1285 .math(&mut rng)
1286 .context("Could not create model density")?;
1287 let trace = trace_config
1288 .new_trace(settings_ref, &math)
1289 .context("Could not create trace object")?;
1290 drop(math);
1291
1292 for chain_id in 0..settings.num_chains() {
1293 let chain_trace_val = trace
1294 .initialize_trace_for_chain(chain_id as u64)
1295 .context("Failed to create trace object")?;
1296 let chain = ChainProcess::start(
1297 model_ref,
1298 chain_trace_val,
1299 chain_id as u64,
1300 settings.seed(),
1301 settings_ref,
1302 scope,
1303 results.clone(),
1304 );
1305 chains.push(chain);
1306 }
1307 drop(results);
1308
1309 let (chains, errors): (Vec<_>, Vec<_>) = chains.into_iter().partition_result();
1310 if let Some(error) = errors.into_iter().next() {
1311 let _ = ChainProcess::finalize_many(trace, chains);
1312 return Err(error).context("Could not start chains");
1313 }
1314
1315 let mut main_loop = || {
1316 let start_time = Instant::now();
1317 let mut pause_start = Instant::now();
1318 let mut pause_time = Duration::ZERO;
1319
1320 let mut progress_rate = Duration::MAX;
1321 if let Some(ProgressCallback { callback, rate }) = &mut callback {
1322 let progress = chains.iter().map(|chain| chain.progress()).collect_vec();
1323 callback(start_time.elapsed(), progress.into());
1324 progress_rate = *rate;
1325 }
1326 let mut last_progress = Instant::now();
1327 let mut is_paused = false;
1328
1329 loop {
1330 let timeout = progress_rate.checked_sub(last_progress.elapsed());
1331 let timeout = timeout.unwrap_or_else(|| {
1332 if let Some(ProgressCallback { callback, .. }) = &mut callback {
1333 let progress =
1334 chains.iter().map(|chain| chain.progress()).collect_vec();
1335 let mut elapsed = start_time.elapsed().saturating_sub(pause_time);
1336 if is_paused {
1337 elapsed = elapsed.saturating_sub(pause_start.elapsed());
1338 }
1339 callback(elapsed, progress.into());
1340 }
1341 last_progress = Instant::now();
1342 progress_rate
1343 });
1344
1345 match commands_rx.recv_timeout(timeout) {
1347 Ok(SamplerCommand::Pause) => {
1348 for chain in chains.iter() {
1349 let _ = chain.pause();
1352 }
1353 if !is_paused {
1354 pause_start = Instant::now();
1355 }
1356 is_paused = true;
1357 responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
1358 anyhow::anyhow!(
1359 "Could not send pause response to controller thread: {e}"
1360 )
1361 })?;
1362 }
1363 Ok(SamplerCommand::Continue) => {
1364 for chain in chains.iter() {
1365 let _ = chain.resume();
1368 }
1369 pause_time += pause_start.elapsed();
1370 is_paused = false;
1371 responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
1372 anyhow::anyhow!(
1373 "Could not send continue response to controller thread: {e}"
1374 )
1375 })?;
1376 }
1377 Ok(SamplerCommand::Progress) => {
1378 let progress =
1379 chains.iter().map(|chain| chain.progress()).collect_vec();
1380 responses_tx.send(SamplerResponse::Progress(progress.into())).map_err(|e| {
1381 anyhow::anyhow!(
1382 "Could not send progress response to controller thread: {e}"
1383 )
1384 })?;
1385 }
1386 Ok(SamplerCommand::Inspect) => {
1387 let traces = chains
1388 .iter()
1389 .filter_map(|chain| {
1390 chain
1391 .trace
1392 .lock()
1393 .expect("Poisoned lock")
1394 .as_ref()
1395 .map(|v| v.inspect())
1396 })
1397 .collect_vec();
1398 let finalized_trace = trace.inspect(traces)?;
1399 responses_tx.send(SamplerResponse::Inspect(finalized_trace)).map_err(|e| {
1400 anyhow::anyhow!(
1401 "Could not send inspect response to controller thread: {e}"
1402 )
1403 })?;
1404 }
1405 Ok(SamplerCommand::Flush) => {
1406 for chain in chains.iter() {
1407 chain.flush()?;
1408 }
1409 responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
1410 anyhow::anyhow!(
1411 "Could not send flush response to controller thread: {e}"
1412 )
1413 })?;
1414 }
1415 Err(RecvTimeoutError::Timeout) => {}
1416 Err(RecvTimeoutError::Disconnected) => {
1417 if let Some(ProgressCallback { callback, .. }) = &mut callback {
1418 let progress =
1419 chains.iter().map(|chain| chain.progress()).collect_vec();
1420 let mut elapsed =
1421 start_time.elapsed().saturating_sub(pause_time);
1422 if is_paused {
1423 elapsed = elapsed.saturating_sub(pause_start.elapsed());
1424 }
1425 callback(elapsed, progress.into());
1426 }
1427 return Ok(());
1428 }
1429 };
1430 }
1431 };
1432 let result: Result<()> = main_loop();
1433 let output = ChainProcess::finalize_many(trace, chains)?;
1435
1436 result?;
1437 Ok(output)
1438 })
1439 });
1440
1441 Ok(Self {
1442 main_thread,
1443 commands: commands_tx,
1444 responses: responses_rx,
1445 results: results_rx,
1446 })
1447 }
1448
1449 pub fn pause(&mut self) -> Result<()> {
1450 self.commands
1451 .send(SamplerCommand::Pause)
1452 .context("Could not send pause command to controller thread")?;
1453 let response = self
1454 .responses
1455 .recv()
1456 .context("Could not recieve pause response from controller thread")?;
1457 let SamplerResponse::Ok() = response else {
1458 bail!("Got invalid response from sample controller thread");
1459 };
1460 Ok(())
1461 }
1462
1463 pub fn resume(&mut self) -> Result<()> {
1464 self.commands.send(SamplerCommand::Continue)?;
1465 let response = self.responses.recv()?;
1466 let SamplerResponse::Ok() = response else {
1467 bail!("Got invalid response from sample controller thread");
1468 };
1469 Ok(())
1470 }
1471
1472 pub fn flush(&mut self) -> Result<()> {
1473 self.commands.send(SamplerCommand::Flush)?;
1474 let response = self
1475 .responses
1476 .recv()
1477 .context("Could not recieve flush response from controller thread")?;
1478 let SamplerResponse::Ok() = response else {
1479 bail!("Got invalid response from sample controller thread");
1480 };
1481 Ok(())
1482 }
1483
1484 pub fn inspect(&mut self) -> Result<(Option<anyhow::Error>, F)> {
1485 self.commands.send(SamplerCommand::Inspect)?;
1486 let response = self
1487 .responses
1488 .recv()
1489 .context("Could not recieve inspect response from controller thread")?;
1490 let SamplerResponse::Inspect(trace) = response else {
1491 bail!("Got invalid response from sample controller thread");
1492 };
1493 Ok(trace)
1494 }
1495
1496 pub fn abort(self) -> Result<(Option<anyhow::Error>, F)> {
1497 drop(self.commands);
1498 let result = self.main_thread.join();
1499 match result {
1500 Err(payload) => std::panic::resume_unwind(payload),
1501 Ok(Ok(val)) => Ok(val),
1502 Ok(Err(err)) => Err(err),
1503 }
1504 }
1505
1506 pub fn wait_timeout(self, timeout: Duration) -> SamplerWaitResult<F> {
1507 let start = Instant::now();
1508 let mut remaining = Some(timeout);
1509 while remaining.is_some() {
1510 match self.results.recv_timeout(timeout) {
1511 Ok(Ok(_)) => remaining = timeout.checked_sub(start.elapsed()),
1512 Ok(Err(e)) => return SamplerWaitResult::Err(e, None),
1513 Err(RecvTimeoutError::Disconnected) => match self.abort() {
1514 Ok((Some(err), trace)) => return SamplerWaitResult::Err(err, Some(trace)),
1515 Ok((None, trace)) => return SamplerWaitResult::Trace(trace),
1516 Err(err) => return SamplerWaitResult::Err(err, None),
1517 },
1518 Err(RecvTimeoutError::Timeout) => break,
1519 }
1520 }
1521 SamplerWaitResult::Timeout(self)
1522 }
1523
1524 pub fn progress(&mut self) -> Result<Box<[ChainProgress]>> {
1525 self.commands.send(SamplerCommand::Progress)?;
1526 let response = self.responses.recv()?;
1527 let SamplerResponse::Progress(progress) = response else {
1528 bail!("Got invalid response from sample controller thread");
1529 };
1530 Ok(progress)
1531 }
1532}
1533
#[cfg(test)]
pub mod test_logps {

    use std::collections::HashMap;

    use crate::math::{CpuLogpFunc, LogpError};
    #[cfg(feature = "zarr")]
    use crate::{Model, math::CpuMath};
    use anyhow::Result;
    use nuts_storable::HasDims;
    #[cfg(feature = "zarr")]
    use rand::Rng;
    use thiserror::Error;

    /// Unit-variance normal log density with mean `mu` in each of `dim` coordinates.
    /// The normalization constant is omitted.
    #[derive(Clone, Debug)]
    pub struct NormalLogp {
        pub dim: usize,
        pub mu: f64,
    }

    /// Error type of [`NormalLogp`]; it has no variants, so evaluation never fails.
    #[derive(Error, Debug)]
    pub enum NormalLogpError {}

    impl LogpError for NormalLogpError {
        fn is_recoverable(&self) -> bool {
            false
        }
    }

    impl HasDims for &NormalLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            let size = self.dim as u64;
            HashMap::from([
                ("unconstrained_parameter".to_string(), size),
                ("dim".to_string(), size),
            ])
        }
    }

    impl CpuLogpFunc for &NormalLogp {
        type LogpError = NormalLogpError;
        type FlowParameters = ();
        type ExpandedVector = Vec<f64>;

        fn dim(&self) -> usize {
            self.dim
        }

        /// Sum of `-(mu - x)^2 / 2` over all coordinates; the gradient entry is `mu - x`.
        fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, NormalLogpError> {
            assert!(gradient.len() == position.len());

            let mu = self.mu;
            let total = position
                .iter()
                .zip(gradient.iter_mut())
                .fold(0f64, |acc, (x, grad)| {
                    let diff = mu - x;
                    *grad = diff;
                    acc - diff * diff / 2.
                });

            Ok(total)
        }

        fn expand_vector<R>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> std::result::Result<Self::ExpandedVector, crate::math::CpuMathError>
        where
            R: rand::Rng + ?Sized,
        {
            Ok(Vec::from(array))
        }

        // The flow/transformation hooks below are not needed by these tests.

        fn inv_transform_normalize(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &[f64],
            _untransformed_gradient: &[f64],
            _transformed_position: &mut [f64],
            _transformed_gradient: &mut [f64],
        ) -> std::result::Result<f64, Self::LogpError> {
            unimplemented!()
        }

        fn init_from_untransformed_position(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &[f64],
            _untransformed_gradient: &mut [f64],
            _transformed_position: &mut [f64],
            _transformed_gradient: &mut [f64],
        ) -> std::result::Result<(f64, f64), Self::LogpError> {
            unimplemented!()
        }

        fn init_from_transformed_position(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &mut [f64],
            _untransformed_gradient: &mut [f64],
            _transformed_position: &[f64],
            _transformed_gradient: &mut [f64],
        ) -> std::result::Result<(f64, f64), Self::LogpError> {
            unimplemented!()
        }

        fn update_transformation<'b, R: rand::Rng + ?Sized>(
            &'b mut self,
            _rng: &mut R,
            _untransformed_positions: impl Iterator<Item = &'b [f64]>,
            _untransformed_gradients: impl Iterator<Item = &'b [f64]>,
            _untransformed_logp: impl Iterator<Item = &'b f64>,
            _params: &'b mut Self::FlowParameters,
        ) -> std::result::Result<(), Self::LogpError> {
            unimplemented!()
        }

        fn init_transformation<R: rand::Rng + ?Sized>(
            &mut self,
            _rng: &mut R,
            _untransformed_position: &[f64],
            _untransformed_gradient: &[f64],
            _chain: u64,
        ) -> std::result::Result<Self::FlowParameters, Self::LogpError> {
            unimplemented!()
        }

        fn transformation_id(
            &self,
            _params: &Self::FlowParameters,
        ) -> std::result::Result<i64, Self::LogpError> {
            unimplemented!()
        }
    }

    /// Minimal [`Model`] wrapping a CPU log density; every chain starts at the origin.
    #[cfg(feature = "zarr")]
    pub struct CpuModel<F> {
        logp: F,
    }

    #[cfg(feature = "zarr")]
    impl<F> CpuModel<F> {
        pub fn new(logp: F) -> Self {
            Self { logp }
        }
    }

    #[cfg(feature = "zarr")]
    impl<F> Model for CpuModel<F>
    where
        F: Send + Sync + 'static,
        for<'a> &'a F: CpuLogpFunc,
    {
        type Math<'model> = CpuMath<&'model F>;

        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            Ok(CpuMath::new(&self.logp))
        }

        fn init_position<R: rand::prelude::Rng + ?Sized>(
            &self,
            _rng: &mut R,
            position: &mut [f64],
        ) -> Result<()> {
            position.fill(0.);
            Ok(())
        }
    }
}
1704
// Tests for chain construction, sequential sampling and the threaded `Sampler` controller.
#[cfg(test)]
mod tests {
    use super::test_logps::NormalLogp;
    use crate::{
        Chain, math::CpuMath, sample_sequentially, sampler::DiagMclmcSettings,
        sampler::DiagNutsSettings, sampler::LowRankMclmcSettings, sampler::LowRankNutsSettings,
        sampler::Settings,
    };

    #[cfg(feature = "zarr")]
    use super::test_logps::CpuModel;

    use anyhow::Result;
    use itertools::Itertools;
    use pretty_assertions::assert_eq;
    use rand::{SeedableRng, rngs::StdRng};

    #[cfg(feature = "zarr")]
    use std::{
        sync::Arc,
        time::{Duration, Instant},
    };

    #[cfg(feature = "zarr")]
    use crate::{Sampler, ZarrConfig};

    #[cfg(feature = "zarr")]
    use zarrs::storage::store::MemoryStore;

    // Check that `settings` exposes consistent stat metadata and that a chain built from it
    // can take one draw on a 4-dimensional normal.
    fn assert_settings_smoke<S: Settings>(settings: S) -> Result<()> {
        let logp = NormalLogp { dim: 4, mu: 0.1 };
        let math = CpuMath::new(&logp);
        let mut rng = StdRng::seed_from_u64(42);

        let stat_names = settings.stat_names(&math);
        let stat_types = settings.stat_types(&math);
        assert!(!stat_names.is_empty());
        assert_eq!(stat_names.len(), stat_types.len());

        let mut chain = settings.new_chain(0, math, &mut rng);
        chain.set_position(&vec![0.2; 4])?;
        let (_draw, _info) = chain.draw()?;
        Ok(())
    }

    // Run the smoke check for every settings type (NUTS and MCLMC, diagonal and low-rank).
    #[test]
    fn all_settings_smoke() -> Result<()> {
        assert_settings_smoke(DiagNutsSettings {
            num_tune: 10,
            num_draws: 10,
            ..Default::default()
        })?;
        assert_settings_smoke(LowRankNutsSettings {
            num_tune: 10,
            num_draws: 10,
            ..Default::default()
        })?;
        assert_settings_smoke(DiagMclmcSettings {
            num_tune: 10,
            num_draws: 10,
            ..Default::default()
        })?;
        assert_settings_smoke(LowRankMclmcSettings {
            num_tune: 10,
            num_draws: 10,
            ..Default::default()
        })?;
        Ok(())
    }

    // The first draw of a fresh chain is a tuning draw with index 0; `sample_sequentially`
    // then yields num_tune + num_draws items, numbered per chain.
    #[test]
    fn sample_chain() -> Result<()> {
        let logp = NormalLogp { dim: 10, mu: 0.1 };
        let math = CpuMath::new(&logp);
        let settings = DiagNutsSettings {
            num_tune: 100,
            num_draws: 100,
            ..Default::default()
        };
        let start = vec![0.2; 10];

        let mut rng = StdRng::seed_from_u64(42);

        let mut chain = settings.new_chain(0, math, &mut rng);

        let (_draw, info) = chain.draw()?;
        assert!(info.tuning);
        assert_eq!(info.draw, 0);

        let math = CpuMath::new(&logp);
        let chain = sample_sequentially(math, settings, &start, 200, 1, &mut rng).unwrap();
        let mut draws = chain.collect_vec();
        assert_eq!(draws.len(), 200);

        let draw0 = draws.remove(100).unwrap();
        let (vals, stats) = draw0;
        assert_eq!(vals.len(), 10);
        assert_eq!(stats.chain, 1);
        assert_eq!(stats.draw, 100);
        Ok(())
    }

    // Exercise the threaded `Sampler`: repeated pause/resume followed by abort, abort while
    // paused, and waiting with a tiny and then a generous timeout. The timing-sensitive
    // asserts assume sampling cannot finish within 100ns but does finish within 1s.
    #[cfg(feature = "zarr")]
    #[test]
    fn sample_parallel() -> Result<()> {
        let logp = NormalLogp { dim: 100, mu: 0.1 };
        let settings = DiagNutsSettings {
            num_tune: 100,
            num_draws: 100,
            seed: 10,
            ..Default::default()
        };

        let model = CpuModel::new(logp.clone());
        let store = MemoryStore::new();

        let zarr_config = ZarrConfig::new(Arc::new(store));
        let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
        sampler.pause()?;
        sampler.pause()?;
        sampler.resume()?;
        let (ok, _) = sampler.abort()?;
        if let Some(err) = ok {
            Err(err)?;
        }

        let store = MemoryStore::new();
        let zarr_config = ZarrConfig::new(Arc::new(store));
        let model = CpuModel::new(logp.clone());
        let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
        sampler.pause()?;
        if let (Some(err), _) = sampler.abort()? {
            Err(err)?;
        }

        let store = MemoryStore::new();
        let zarr_config = ZarrConfig::new(Arc::new(store));
        let model = CpuModel::new(logp.clone());
        let start = Instant::now();
        let sampler = Sampler::new(model, settings, zarr_config, 4, None)?;

        // A 100ns budget must time out and hand the sampler back.
        let mut sampler = match sampler.wait_timeout(Duration::from_nanos(100)) {
            super::SamplerWaitResult::Trace(_) => {
                dbg!(start.elapsed());
                panic!("finished");
            }
            super::SamplerWaitResult::Timeout(sampler) => sampler,
            super::SamplerWaitResult::Err(_, _) => {
                panic!("error")
            }
        };

        for _ in 0..30 {
            sampler.progress()?;
        }

        // With a 1s budget the sampler is expected to finish and return its trace.
        match sampler.wait_timeout(Duration::from_secs(1)) {
            super::SamplerWaitResult::Trace(_) => {
                dbg!(start.elapsed());
            }
            super::SamplerWaitResult::Timeout(_) => {
                panic!("timeout")
            }
            super::SamplerWaitResult::Err(err, _) => Err(err)?,
        };

        Ok(())
    }

    // `sample_sequentially` yields 200 draws; draw 100 belongs to chain 1 and has index 100.
    #[test]
    fn sample_seq() {
        let logp = NormalLogp { dim: 10, mu: 0.1 };
        let math = CpuMath::new(&logp);
        let settings = DiagNutsSettings {
            num_tune: 100,
            num_draws: 100,
            ..Default::default()
        };
        let start = vec![0.2; 10];

        let mut rng = StdRng::seed_from_u64(42);

        let chain = sample_sequentially(math, settings, &start, 200, 1, &mut rng).unwrap();
        let mut draws = chain.collect_vec();
        assert_eq!(draws.len(), 200);

        let draw0 = draws.remove(100).unwrap();
        let (vals, stats) = draw0;
        assert_eq!(vals.len(), 10);
        assert_eq!(stats.chain, 1);
        assert_eq!(stats.draw, 100);
    }
}