Skip to main content

nuts_rs/
sampler.rs

1//! High-level sampler entry points: `Settings` presets, the parallel `Sampler`,
2//! and `sample_sequentially` for running one or many chains.
3
4use anyhow::Result;
5use nuts_storable::{HasDims, Storable, Value};
6use rand::{Rng, SeedableRng, rngs::ChaCha8Rng};
7use serde::{Deserialize, Serialize, de::DeserializeOwned};
8use std::{collections::HashMap, fmt::Debug, time::Duration};
9
10#[cfg(feature = "parallel")]
11use anyhow::{Context, bail};
12#[cfg(feature = "parallel")]
13use itertools::Itertools;
14#[cfg(feature = "parallel")]
15use std::ops::Deref;
16
17#[cfg(feature = "parallel")]
18use rayon::{ScopeFifo, ThreadPoolBuilder};
19#[cfg(feature = "parallel")]
20use std::{
21    sync::{
22        Arc, Mutex,
23        mpsc::{
24            Receiver, RecvTimeoutError, Sender, SyncSender, TryRecvError, channel, sync_channel,
25        },
26    },
27    thread::{JoinHandle, spawn},
28    time::Instant,
29};
30
31use crate::{
32    DiagAdaptExpSettings, Math, StepSizeAdaptMethod,
33    adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy, GlobalStrategyStatsOptions},
34    chain::{AdaptStrategy, Chain, NutsChain, StatOptions},
35    dynamics::{KineticEnergyKind, TransformedHamiltonian, TransformedPointStatsOptions},
36    external_adapt_strategy::{ExternalTransformAdaptation, FlowSettings},
37    mclmc::MclmcTrajectoryKind,
38    nuts::NutsOptions,
39    sampler_stats::{SamplerStats, StatsDims},
40    transform::{
41        DiagAdaptStrategy, DiagMassMatrix, ExternalTransformation, LowRankMassMatrix,
42        LowRankMassMatrixStrategy, LowRankSettings,
43    },
44};
45
46#[cfg(feature = "parallel")]
47use crate::{
48    model::Model,
49    storage::{ChainStorage, StorageConfig, TraceStorage},
50};
51
52/// All sampler configurations implement this trait
53pub trait Settings:
54    private::Sealed + Clone + Copy + Default + Sync + Send + Serialize + DeserializeOwned + 'static
55{
56    type Chain<M: Math>: Chain<M>;
57
58    fn new_chain<M: Math, R: Rng + ?Sized>(
59        &self,
60        chain: u64,
61        math: M,
62        rng: &mut R,
63    ) -> Self::Chain<M>;
64
65    fn hint_num_tune(&self) -> usize;
66    fn hint_num_draws(&self) -> usize;
67    fn num_chains(&self) -> usize;
68    fn seed(&self) -> u64;
69    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions;
70    fn sampler_name(&self) -> &'static str;
71    fn adaptation_name(&self) -> &'static str;
72
73    fn stat_names<M: Math>(&self, math: &M) -> Vec<String> {
74        let dims = StatsDims::from(math);
75        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::names(&dims)
76            .into_iter()
77            .map(String::from)
78            .collect()
79    }
80
81    fn data_names<M: Math>(&self, math: &M) -> Vec<String> {
82        <M::ExpandedVector as Storable<_>>::names(math)
83            .into_iter()
84            .map(String::from)
85            .collect()
86    }
87
88    fn stat_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
89        self.stat_names(math)
90            .into_iter()
91            .map(|name| (name.clone(), self.stat_type::<M>(math, &name)))
92            .collect()
93    }
94
95    fn stat_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
96        let dims = StatsDims::from(math);
97        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::item_type(&dims, name)
98    }
99
100    fn data_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
101        self.data_names(math)
102            .into_iter()
103            .map(|name| (name.clone(), self.data_type(math, &name)))
104            .collect()
105    }
106    fn data_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
107        <M::ExpandedVector as Storable<_>>::item_type(math, name)
108    }
109
110    fn stat_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
111        self.stat_names(math)
112            .into_iter()
113            .map(|name| (name.clone(), self.stat_dims::<M>(math, &name)))
114            .collect()
115    }
116
117    fn stat_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
118        let dims = StatsDims::from(math);
119        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::dims(&dims, name)
120            .into_iter()
121            .map(String::from)
122            .collect()
123    }
124
125    fn stat_dim_sizes<M: Math>(&self, math: &M) -> HashMap<String, u64> {
126        let dims = StatsDims::from(math);
127        dims.dim_sizes()
128    }
129
130    fn data_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
131        self.data_names(math)
132            .into_iter()
133            .map(|name| (name.clone(), self.data_dims(math, &name)))
134            .collect()
135    }
136
137    fn data_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
138        <M::ExpandedVector as Storable<_>>::dims(math, name)
139            .into_iter()
140            .map(String::from)
141            .collect()
142    }
143
144    fn stat_coords<M: Math>(&self, math: &M) -> HashMap<String, Value> {
145        let dims = StatsDims::from(math);
146        dims.coords()
147    }
148
149    fn stat_event_dims<M: Math>(&self, math: &M) -> Vec<(String, Option<String>)> {
150        let dims = StatsDims::from(math);
151        self.stat_names(math)
152            .into_iter()
153            .map(|name| {
154                let event_dim =
155                    <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::event_dim(
156                        &dims, &name,
157                    )
158                    .map(String::from);
159                (name, event_dim)
160            })
161            .collect()
162    }
163}
164
/// Progress information about one chain.
///
/// NOTE(review): the code that fills in this struct is not visible in this
/// part of the file. The field notes below are inferred from names and types
/// and should be confirmed against the producer.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Progress {
    /// Draw counter — presumably the index of the current draw; TODO confirm.
    pub draw: u64,
    /// Chain identifier — presumably the same index passed to `Settings::new_chain`; TODO confirm.
    pub chain: u64,
    /// Presumably whether the draw was flagged as a divergence; TODO confirm.
    pub diverging: bool,
    /// Presumably whether the chain is still in its tuning phase; TODO confirm.
    pub tuning: bool,
    /// Presumably the integrator step size used for the draw; TODO confirm.
    pub step_size: f64,
    /// Presumably the number of integrator steps taken for the draw; TODO confirm.
    pub num_steps: u64,
}
175
176mod private {
177    use super::{
178        DiagMclmcSettings, DiagNutsSettings, FlowMclmcSettings, FlowNutsSettings,
179        LowRankMclmcSettings, LowRankNutsSettings,
180    };
181
182    pub trait Sealed {}
183
184    impl Sealed for DiagNutsSettings {}
185
186    impl Sealed for LowRankNutsSettings {}
187
188    impl Sealed for FlowNutsSettings {}
189
190    impl Sealed for DiagMclmcSettings {}
191
192    impl Sealed for LowRankMclmcSettings {}
193
194    impl Sealed for FlowMclmcSettings {}
195}
196
/// Settings for the NUTS sampler
///
/// `A` is the type of the geometry-adaptation options stored in
/// `adapt_options`. Concrete configurations are provided through the
/// `DiagNutsSettings`, `LowRankNutsSettings` and `FlowNutsSettings` aliases.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct NutsSettings<A: Debug + Copy + Default + Serialize> {
    /// The number of tuning steps, where we fit the step size and geometry.
    pub num_tune: u64,
    /// The number of draws after tuning
    pub num_draws: u64,
    /// The maximum tree depth during sampling. The number of leapfrog steps
    /// is smaller than 2 ^ maxdepth.
    pub maxdepth: u64,
    /// The minimum tree depth during sampling. The number of leapfrog steps
    /// is larger than 2 ^ mindepth.
    pub mindepth: u64,
    /// Store the gradient in the SampleStats
    pub store_gradient: bool,
    /// Store each unconstrained parameter vector in the sampler stats
    pub store_unconstrained: bool,
    /// Store the transformed gradient and value in the sampler stats
    pub store_transformed: bool,
    /// If the energy error is larger than this threshold we treat the leapfrog
    /// step as a divergence.
    pub max_energy_error: f64,
    /// Store detailed information about each divergence in the sampler stats
    pub store_divergences: bool,
    /// Settings for geometry adaptation.
    pub adapt_options: A,
    /// Forwarded unchanged to `NutsOptions::check_turning`; the built-in
    /// presets set it to `true`. Presumably toggles the U-turn termination
    /// check — TODO confirm.
    pub check_turning: bool,
    /// Forwarded unchanged to `NutsOptions::target_integration_time`; the
    /// built-in presets leave it as `None`.
    pub target_integration_time: Option<f64>,
    /// Selects the kinetic-energy form and the corresponding integrator.
    ///
    /// - [`KineticEnergyKind::Euclidean`]: standard leapfrog (default for most settings).
    /// - [`KineticEnergyKind::ExactNormal`]: geodesic leapfrog exact for a standard-normal
    ///   potential.
    /// - [`KineticEnergyKind::Microcanonical`]: isokinetic ESH-dynamics leapfrog (microcanonical
    ///   HMC); momentum is constrained to the unit sphere.
    pub trajectory_kind: KineticEnergyKind,
    /// Number of chains; returned by `Settings::num_chains`.
    pub num_chains: usize,
    /// Seed value; returned by `Settings::seed`.
    pub seed: u64,
    /// Number of extra doublings to perform after reaching maxdepth. This can
    /// be used to increase the effective sample size at the cost of more
    /// expensive sampling.
    pub extra_doublings: u64,
}
240
/// NUTS settings using Euclidean adaptation options with diagonal
/// ([`DiagAdaptExpSettings`]) mass-matrix settings.
pub type DiagNutsSettings = NutsSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;
/// Backwards-compatible alias for [`DiagNutsSettings`].
#[deprecated(since = "0.0.0", note = "Use DiagNutsSettings instead")]
pub type DiagGradNutsSettings = DiagNutsSettings;
/// NUTS settings using Euclidean adaptation options with low-rank
/// ([`LowRankSettings`]) mass-matrix settings.
pub type LowRankNutsSettings = NutsSettings<EuclideanAdaptOptions<LowRankSettings>>;
/// NUTS settings whose adaptation options are [`FlowSettings`].
pub type FlowNutsSettings = NutsSettings<FlowSettings>;
/// Backwards-compatible alias for [`FlowNutsSettings`].
#[deprecated(since = "0.0.0", note = "Use FlowNutsSettings instead")]
pub type TransformedNutsSettings = FlowNutsSettings;
250
/// Settings for the unadjusted Microcanonical Langevin Monte Carlo (MCLMC) sampler.
///
/// > ⚠️ **Experimental — use with caution**: The MCLMC sampler and all of its
/// > variants are highly experimental. They have not been thoroughly validated
/// > and may **not return correct posteriors**. The API, defaults, and
/// > adaptation behaviour are all subject to breaking changes at any time.
/// > Do not use these samplers in production or for results you rely on.
///
/// Step size `ε` and momentum decoherence length `L` are **constants** — no
/// adaptation of those is performed yet. The geometry is adapted during
/// warmup using the sampler-specific adaptation strategy, while the step size
/// remains fixed.
///
/// Use the type aliases [`DiagMclmcSettings`], [`LowRankMclmcSettings`], and
/// [`FlowMclmcSettings`] for concrete configurations.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct MclmcSettings<A: Debug + Copy + Default + Serialize> {
    /// Step size ε for the ESH leapfrog integrator.
    pub step_size: f64,
    /// Momentum decoherence length L (controls partial momentum refresh rate).
    /// Set to `f64::INFINITY` to disable momentum refresh entirely.
    pub momentum_decoherence_length: f64,
    /// Number of warmup draws.
    pub num_tune: u64,
    /// Number of sampling draws after warmup.
    pub num_draws: u64,
    /// Number of parallel chains.
    pub num_chains: usize,
    /// RNG seed.
    pub seed: u64,
    /// Maximum energy error before a step is flagged as a divergence.
    pub max_energy_error: f64,
    /// Store each unconstrained parameter vector in the sampler stats.
    pub store_unconstrained: bool,
    /// Store the gradient in the sampler stats.
    pub store_gradient: bool,
    /// Store the transformed gradient and value in the sampler stats
    pub store_transformed: bool,
    /// Store detailed information about each divergence in the sampler stats
    pub store_divergences: bool,
    /// Geometry adaptation options (step-size fields are ignored for Euclidean settings).
    pub adapt_options: A,
    /// Number of leapfrog steps per draw as a fraction of `L / ε`.
    ///
    /// The number of leapfrog steps between collector calls is:
    /// `round(subsample_frequency * L / ε).max(1)`
    ///
    /// - `1.0` (default) — one sample per full trajectory (at the final step).
    /// - `0.0` — every leapfrog step.
    /// - Values in between space samples as a fraction of the decoherence
    ///   length, so the interval scales naturally when `L` or `ε` changes.
    pub subsample_frequency: f64,
    /// When `true`, use the tree-structured step size retry on divergence:
    /// halve the step size factor and try 2 steps before doubling back.
    /// `log_weight` will include `log(step_size)` to correct for the varying
    /// sampling density. When `false`, divergences are recorded immediately
    /// without any retry and `log_weight = -energy_change`.
    pub dynamic_step_size: bool,
    /// Selects which leapfrog integrator and partial-momentum-refresh style
    /// to use.  See [`MclmcTrajectoryKind`] for the available options.
    /// The `Default` presets in this module use
    /// [`MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical`];
    /// [`MclmcTrajectoryKind::Microcanonical`] selects the original MCLMC.
    pub trajectory_kind: MclmcTrajectoryKind,
    /// Fraction of `num_tune` draws at which the trajectory is switched from
    /// Euclidean to Microcanonical when
    /// `trajectory_kind == MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical`.
    /// Ignored for other trajectory kinds.  Default: `0.3`.
    pub trajectory_switch_fraction: f64,
}
319
/// MCLMC settings with a diagonal mass matrix adaptation.
///
/// Its `Settings` implementation reports sampler name `"mclmc"` and
/// adaptation name `"diagonal"`.
///
/// > ⚠️ **Experimental — use with caution**: Highly experimental. Correctness
/// > of the returned posteriors has not been verified. May change at any time.
pub type DiagMclmcSettings = MclmcSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;
/// MCLMC settings with a low-rank mass matrix adaptation.
///
/// Its `Settings` implementation reports sampler name `"mclmc"` and
/// adaptation name `"low_rank"`.
///
/// > ⚠️ **Experimental — use with caution**: Highly experimental. Correctness
/// > of the returned posteriors has not been verified. May change at any time.
pub type LowRankMclmcSettings = MclmcSettings<EuclideanAdaptOptions<LowRankSettings>>;
/// MCLMC settings with a learned flow transformation.
///
/// > ⚠️ **Experimental — use with caution**: Highly experimental. Correctness
/// > of the returned posteriors has not been verified. May change at any time.
pub type FlowMclmcSettings = MclmcSettings<FlowSettings>;
/// Backwards-compatible alias for [`FlowMclmcSettings`].
#[deprecated(since = "0.0.0", note = "Use FlowMclmcSettings instead")]
pub type TransformedMclmcSettings = FlowMclmcSettings;
338
/// Convert a `u64` setting into a `usize`.
///
/// `field` is the name of the setting and is only used in the panic message.
///
/// # Panics
///
/// Panics if `value` does not fit into a `usize` on the current platform.
fn usize_hint(value: u64, field: &str) -> usize {
    match usize::try_from(value) {
        Ok(converted) => converted,
        Err(_) => panic!("{field} must be smaller than usize::MAX"),
    }
}
344
345fn default_mclmc_settings<A: Debug + Copy + Default + Serialize>(
346    adapt_options: A,
347    num_tune: u64,
348    num_chains: usize,
349    max_energy_error: f64,
350) -> MclmcSettings<A> {
351    MclmcSettings {
352        step_size: 0.5,
353        momentum_decoherence_length: 3.0,
354        num_tune,
355        num_draws: 1000,
356        num_chains,
357        seed: 0,
358        max_energy_error,
359        store_unconstrained: false,
360        store_gradient: false,
361        store_divergences: false,
362        store_transformed: false,
363        adapt_options,
364        subsample_frequency: 1.0,
365        dynamic_step_size: true,
366        trajectory_kind: MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical,
367        trajectory_switch_fraction: 0.3,
368    }
369}
370
371impl Default for DiagMclmcSettings {
372    fn default() -> Self {
373        let mut adapt_options = EuclideanAdaptOptions::default();
374        adapt_options.step_size_settings.adapt_options.method = StepSizeAdaptMethod::Fixed(0.5);
375        default_mclmc_settings(adapt_options, 400, 6, 1000.0)
376    }
377}
378
379impl Default for LowRankMclmcSettings {
380    fn default() -> Self {
381        let mut adapt_options = EuclideanAdaptOptions::default();
382        adapt_options.early_mass_matrix_switch_freq = 20;
383        adapt_options.step_size_settings.adapt_options.method = StepSizeAdaptMethod::Fixed(0.5);
384        default_mclmc_settings(adapt_options, 800, 6, 1000.0)
385    }
386}
387
388impl Default for FlowMclmcSettings {
389    fn default() -> Self {
390        default_mclmc_settings(FlowSettings::default(), 1500, 1, 20.0)
391    }
392}
393
/// MCLMC chain driven by a [`ChaCha8Rng`], using [`GlobalStrategy`] over
/// [`DiagAdaptStrategy`] for adaptation and a [`DiagMassMatrix`].
type DiagMclmcChain<M> = crate::mclmc::MclmcChain<
    M,
    ChaCha8Rng,
    GlobalStrategy<M, DiagAdaptStrategy<M>>,
    DiagMassMatrix<M>,
>;
/// MCLMC chain driven by a [`ChaCha8Rng`], using [`GlobalStrategy`] over
/// [`LowRankMassMatrixStrategy`] for adaptation and a [`LowRankMassMatrix`].
type LowRankMclmcChain<M> = crate::mclmc::MclmcChain<
    M,
    ChaCha8Rng,
    GlobalStrategy<M, LowRankMassMatrixStrategy>,
    LowRankMassMatrix<M>,
>;
406
407impl Settings for DiagMclmcSettings {
408    type Chain<M: Math> = DiagMclmcChain<M>;
409
410    fn new_chain<M: Math, R: Rng + ?Sized>(
411        &self,
412        chain: u64,
413        mut math: M,
414        rng: &mut R,
415    ) -> Self::Chain<M> {
416        use crate::dynamics::KineticEnergyKind;
417        use crate::mclmc::MclmcChain;
418        use crate::stepsize::StepSizeAdaptMethod;
419
420        let num_tune = self.num_tune;
421        let mut adapt_options = self.adapt_options;
422        adapt_options.step_size_settings.adapt_options.method =
423            StepSizeAdaptMethod::Fixed(self.step_size);
424        let strategy = GlobalStrategy::<M, DiagAdaptStrategy<M>>::new(
425            &mut math,
426            adapt_options,
427            num_tune,
428            chain,
429        );
430        let mass_matrix = DiagMassMatrix::new(
431            &mut math,
432            self.adapt_options.mass_matrix_options.store_mass_matrix,
433        );
434        let initial_kind = match self.trajectory_kind {
435            MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
436            MclmcTrajectoryKind::Euclidean
437            | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
438        };
439        let mut hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, initial_kind);
440        hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
441        let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
442        let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
443        let stats_options = self.stats_options::<M>();
444        MclmcChain::new(
445            math,
446            hamiltonian,
447            strategy,
448            rng,
449            chain,
450            self.subsample_frequency,
451            self.dynamic_step_size,
452            self.trajectory_kind,
453            switch_draw,
454            self.max_energy_error,
455            stats_options,
456        )
457    }
458
459    fn hint_num_tune(&self) -> usize {
460        usize_hint(self.num_tune, "num_tune")
461    }
462
463    fn hint_num_draws(&self) -> usize {
464        usize_hint(self.num_draws, "num_draws")
465    }
466
467    fn num_chains(&self) -> usize {
468        self.num_chains
469    }
470
471    fn seed(&self) -> u64 {
472        self.seed
473    }
474
475    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
476        StatOptions {
477            adapt: GlobalStrategyStatsOptions {
478                step_size: (),
479                mass_matrix: (),
480            },
481            hamiltonian: -1,
482            point: {
483                let store_gradient = self.store_gradient;
484                let store_unconstrained = self.store_unconstrained;
485                let store_transformed = self.store_transformed;
486                TransformedPointStatsOptions {
487                    store_gradient,
488                    store_unconstrained,
489                    store_transformed,
490                }
491            },
492            divergence: crate::dynamics::DivergenceStatsOptions {
493                store_divergences: self.store_divergences,
494            },
495        }
496    }
497
498    fn sampler_name(&self) -> &'static str {
499        "mclmc"
500    }
501
502    fn adaptation_name(&self) -> &'static str {
503        "diagonal"
504    }
505}
506
507fn default_nuts_settings<A: Debug + Copy + Default + Serialize>(
508    adapt_options: A,
509    num_tune: u64,
510    num_chains: usize,
511    max_energy_error: f64,
512) -> NutsSettings<A> {
513    NutsSettings {
514        num_tune,
515        num_draws: 1000,
516        maxdepth: 10,
517        mindepth: 0,
518        max_energy_error,
519        store_gradient: false,
520        store_unconstrained: false,
521        store_transformed: false,
522        store_divergences: false,
523        adapt_options,
524        check_turning: true,
525        seed: 0,
526        num_chains,
527        target_integration_time: None,
528        trajectory_kind: KineticEnergyKind::Euclidean,
529        extra_doublings: 0,
530    }
531}
532
533impl Settings for LowRankMclmcSettings {
534    type Chain<M: Math> = LowRankMclmcChain<M>;
535
536    fn new_chain<M: Math, R: Rng + ?Sized>(
537        &self,
538        chain: u64,
539        mut math: M,
540        rng: &mut R,
541    ) -> Self::Chain<M> {
542        use crate::dynamics::KineticEnergyKind;
543        use crate::mclmc::MclmcChain;
544        use crate::stepsize::StepSizeAdaptMethod;
545
546        let num_tune = self.num_tune;
547        let mut adapt_options = self.adapt_options;
548        adapt_options.step_size_settings.adapt_options.method =
549            StepSizeAdaptMethod::Fixed(self.step_size);
550        let strategy = GlobalStrategy::<M, LowRankMassMatrixStrategy>::new(
551            &mut math,
552            adapt_options,
553            num_tune,
554            chain,
555        );
556        let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
557        let initial_kind = match self.trajectory_kind {
558            MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
559            MclmcTrajectoryKind::Euclidean
560            | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
561        };
562        let mut hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, initial_kind);
563        hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
564        let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
565        let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
566        let stats_options = self.stats_options::<M>();
567        MclmcChain::new(
568            math,
569            hamiltonian,
570            strategy,
571            rng,
572            chain,
573            self.subsample_frequency,
574            self.dynamic_step_size,
575            self.trajectory_kind,
576            switch_draw,
577            self.max_energy_error,
578            stats_options,
579        )
580    }
581
582    fn hint_num_tune(&self) -> usize {
583        usize_hint(self.num_tune, "num_tune")
584    }
585
586    fn hint_num_draws(&self) -> usize {
587        usize_hint(self.num_draws, "num_draws")
588    }
589
590    fn num_chains(&self) -> usize {
591        self.num_chains
592    }
593
594    fn seed(&self) -> u64 {
595        self.seed
596    }
597
598    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
599        StatOptions {
600            adapt: GlobalStrategyStatsOptions {
601                step_size: (),
602                mass_matrix: (),
603            },
604            hamiltonian: -1,
605            point: {
606                let store_gradient = self.store_gradient;
607                let store_unconstrained = self.store_unconstrained;
608                let store_transformed = self.store_transformed;
609                TransformedPointStatsOptions {
610                    store_gradient,
611                    store_unconstrained,
612                    store_transformed,
613                }
614            },
615            divergence: crate::dynamics::DivergenceStatsOptions {
616                store_divergences: self.store_divergences,
617            },
618        }
619    }
620
621    fn sampler_name(&self) -> &'static str {
622        "mclmc"
623    }
624
625    fn adaptation_name(&self) -> &'static str {
626        "low_rank"
627    }
628}
629
630impl Default for DiagNutsSettings {
631    fn default() -> Self {
632        default_nuts_settings(EuclideanAdaptOptions::default(), 400, 6, 1000.0)
633    }
634}
635
636impl Default for LowRankNutsSettings {
637    fn default() -> Self {
638        let mut vals = default_nuts_settings(EuclideanAdaptOptions::default(), 800, 6, 1000.0);
639        vals.adapt_options.mass_matrix_update_freq = 20;
640        vals
641    }
642}
643
644impl Default for FlowNutsSettings {
645    fn default() -> Self {
646        default_nuts_settings(FlowSettings::default(), 1500, 1, 20.0)
647    }
648}
649
/// NUTS chain driven by a [`ChaCha8Rng`], adapted by [`GlobalStrategy`] over
/// [`DiagAdaptStrategy`].
type DiagNutsChain<M> = NutsChain<M, ChaCha8Rng, GlobalStrategy<M, DiagAdaptStrategy<M>>>;
/// NUTS chain driven by a [`ChaCha8Rng`], adapted by [`GlobalStrategy`] over
/// [`LowRankMassMatrixStrategy`].
type LowRankNutsChain<M> = NutsChain<M, ChaCha8Rng, GlobalStrategy<M, LowRankMassMatrixStrategy>>;
652
653fn nuts_options(settings: &NutsSettings<impl Debug + Copy + Default + Serialize>) -> NutsOptions {
654    NutsOptions {
655        maxdepth: settings.maxdepth,
656        mindepth: settings.mindepth,
657        store_divergences: settings.store_divergences,
658        check_turning: settings.check_turning,
659        target_integration_time: settings.target_integration_time,
660        extra_doublings: settings.extra_doublings,
661        max_energy_error: settings.max_energy_error,
662    }
663}
664
665impl Settings for LowRankNutsSettings {
666    type Chain<M: Math> = LowRankNutsChain<M>;
667
668    fn new_chain<M: Math, R: Rng + ?Sized>(
669        &self,
670        chain: u64,
671        mut math: M,
672        mut rng: &mut R,
673    ) -> Self::Chain<M> {
674        let num_tune = self.num_tune;
675        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
676        let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
677        let hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, self.trajectory_kind);
678
679        let options = nuts_options(self);
680
681        let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
682
683        NutsChain::new(
684            math,
685            hamiltonian,
686            strategy,
687            options,
688            rng,
689            chain,
690            self.stats_options(),
691        )
692    }
693
694    fn hint_num_tune(&self) -> usize {
695        usize_hint(self.num_tune, "num_tune")
696    }
697
698    fn hint_num_draws(&self) -> usize {
699        usize_hint(self.num_draws, "num_draws")
700    }
701
702    fn num_chains(&self) -> usize {
703        self.num_chains
704    }
705
706    fn seed(&self) -> u64 {
707        self.seed
708    }
709
710    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
711        StatOptions {
712            adapt: GlobalStrategyStatsOptions {
713                mass_matrix: (),
714                step_size: (),
715            },
716            hamiltonian: -1,
717            point: {
718                let store_gradient = self.store_gradient;
719                let store_unconstrained = self.store_unconstrained;
720                let store_transformed = self.store_transformed;
721                TransformedPointStatsOptions {
722                    store_gradient,
723                    store_unconstrained,
724                    store_transformed,
725                }
726            },
727            divergence: crate::dynamics::DivergenceStatsOptions {
728                store_divergences: self.store_divergences,
729            },
730        }
731    }
732
733    fn sampler_name(&self) -> &'static str {
734        "nuts"
735    }
736
737    fn adaptation_name(&self) -> &'static str {
738        "low_rank"
739    }
740}
741
742impl Settings for DiagNutsSettings {
743    type Chain<M: Math> = DiagNutsChain<M>;
744
745    fn new_chain<M: Math, R: Rng + ?Sized>(
746        &self,
747        chain: u64,
748        mut math: M,
749        mut rng: &mut R,
750    ) -> Self::Chain<M> {
751        let num_tune = self.num_tune;
752        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
753        let mass_matrix = DiagMassMatrix::new(
754            &mut math,
755            self.adapt_options.mass_matrix_options.store_mass_matrix,
756        );
757        let potential = TransformedHamiltonian::new(&mut math, mass_matrix, self.trajectory_kind);
758
759        let options = nuts_options(self);
760
761        let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
762
763        NutsChain::new(
764            math,
765            potential,
766            strategy,
767            options,
768            rng,
769            chain,
770            self.stats_options(),
771        )
772    }
773
774    fn hint_num_tune(&self) -> usize {
775        usize_hint(self.num_tune, "num_tune")
776    }
777
778    fn hint_num_draws(&self) -> usize {
779        usize_hint(self.num_draws, "num_draws")
780    }
781
782    fn num_chains(&self) -> usize {
783        self.num_chains
784    }
785
786    fn seed(&self) -> u64 {
787        self.seed
788    }
789
790    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
791        StatOptions {
792            adapt: GlobalStrategyStatsOptions {
793                mass_matrix: (),
794                step_size: (),
795            },
796            hamiltonian: -1,
797            point: {
798                let store_gradient = self.store_gradient;
799                let store_unconstrained = self.store_unconstrained;
800                let store_transformed = self.store_transformed;
801                TransformedPointStatsOptions {
802                    store_gradient,
803                    store_unconstrained,
804                    store_transformed,
805                }
806            },
807            divergence: crate::dynamics::DivergenceStatsOptions {
808                store_divergences: self.store_divergences,
809            },
810        }
811    }
812
813    fn sampler_name(&self) -> &'static str {
814        "nuts"
815    }
816
817    fn adaptation_name(&self) -> &'static str {
818        "diagonal"
819    }
820}
821
822impl Settings for FlowNutsSettings {
823    type Chain<M: Math> = NutsChain<M, ChaCha8Rng, ExternalTransformAdaptation>;
824
825    fn new_chain<M: Math, R: Rng + ?Sized>(
826        &self,
827        chain: u64,
828        mut math: M,
829        mut rng: &mut R,
830    ) -> Self::Chain<M> {
831        let num_tune = self.num_tune;
832
833        let strategy =
834            ExternalTransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
835        let params = math
836            .new_transformation(rng, math.dim(), chain)
837            .expect("Failed to create external transformation");
838        let transform = ExternalTransformation::new(params);
839        let hamiltonian = TransformedHamiltonian::new(&mut math, transform, self.trajectory_kind);
840
841        let options = nuts_options(self);
842
843        let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
844        NutsChain::new(
845            math,
846            hamiltonian,
847            strategy,
848            options,
849            rng,
850            chain,
851            self.stats_options(),
852        )
853    }
854
855    fn hint_num_tune(&self) -> usize {
856        usize_hint(self.num_tune, "num_tune")
857    }
858
859    fn hint_num_draws(&self) -> usize {
860        usize_hint(self.num_draws, "num_draws")
861    }
862
863    fn num_chains(&self) -> usize {
864        self.num_chains
865    }
866
867    fn seed(&self) -> u64 {
868        self.seed
869    }
870
871    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
872        StatOptions {
873            adapt: (),
874            hamiltonian: (),
875            point: {
876                let store_gradient = self.store_gradient;
877                let store_unconstrained = self.store_unconstrained;
878                let store_transformed = self.store_transformed;
879                TransformedPointStatsOptions {
880                    store_gradient,
881                    store_unconstrained,
882                    store_transformed,
883                }
884            },
885            divergence: crate::dynamics::DivergenceStatsOptions {
886                store_divergences: self.store_divergences,
887            },
888        }
889    }
890
891    fn sampler_name(&self) -> &'static str {
892        "nuts"
893    }
894
895    fn adaptation_name(&self) -> &'static str {
896        "flow"
897    }
898}
899
900impl Settings for FlowMclmcSettings {
901    type Chain<M: Math> = crate::mclmc::MclmcChain<
902        M,
903        ChaCha8Rng,
904        ExternalTransformAdaptation,
905        ExternalTransformation<M>,
906    >;
907
908    fn new_chain<M: Math, R: Rng + ?Sized>(
909        &self,
910        chain: u64,
911        mut math: M,
912        rng: &mut R,
913    ) -> Self::Chain<M> {
914        use crate::dynamics::KineticEnergyKind;
915        use crate::mclmc::MclmcChain;
916
917        let num_tune = self.num_tune;
918        let strategy =
919            ExternalTransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
920        let params = math
921            .new_transformation(rng, math.dim(), chain)
922            .expect("Failed to create external transformation");
923        let transform = ExternalTransformation::new(params);
924        let initial_kind = match self.trajectory_kind {
925            MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
926            MclmcTrajectoryKind::Euclidean
927            | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
928        };
929        let mut hamiltonian = TransformedHamiltonian::new(&mut math, transform, initial_kind);
930        hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
931        let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
932        let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
933        let stats_options = self.stats_options::<M>();
934        MclmcChain::new(
935            math,
936            hamiltonian,
937            strategy,
938            rng,
939            chain,
940            self.subsample_frequency,
941            self.dynamic_step_size,
942            self.trajectory_kind,
943            switch_draw,
944            self.max_energy_error,
945            stats_options,
946        )
947    }
948
949    fn hint_num_tune(&self) -> usize {
950        usize_hint(self.num_tune, "num_tune")
951    }
952
953    fn hint_num_draws(&self) -> usize {
954        usize_hint(self.num_draws, "num_draws")
955    }
956
957    fn num_chains(&self) -> usize {
958        self.num_chains
959    }
960
961    fn seed(&self) -> u64 {
962        self.seed
963    }
964
965    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
966        StatOptions {
967            adapt: (),
968            hamiltonian: (),
969            point: {
970                let store_gradient = self.store_gradient;
971                let store_unconstrained = self.store_unconstrained;
972                let store_transformed = self.store_transformed;
973                TransformedPointStatsOptions {
974                    store_gradient,
975                    store_unconstrained,
976                    store_transformed,
977                }
978            },
979            divergence: crate::dynamics::DivergenceStatsOptions {
980                store_divergences: self.store_divergences,
981            },
982        }
983    }
984
985    fn sampler_name(&self) -> &'static str {
986        "mclmc"
987    }
988
989    fn adaptation_name(&self) -> &'static str {
990        "flow"
991    }
992}
993
994pub fn sample_sequentially<'math, M: Math + 'math, R: Rng + ?Sized>(
995    math: M,
996    settings: DiagNutsSettings,
997    start: &[f64],
998    draws: u64,
999    chain: u64,
1000    rng: &mut R,
1001) -> Result<impl Iterator<Item = Result<(Box<[f64]>, Progress)>> + 'math> {
1002    let mut sampler = settings.new_chain(chain, math, rng);
1003    sampler.set_position(start)?;
1004    Ok((0..draws).map(move |_| sampler.draw()))
1005}
1006
/// Snapshot of the sampling progress of a single chain.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct ChainProgress {
    /// Number of draws completed so far (tuning and post-tuning draws are both counted).
    pub finished_draws: usize,
    /// Total number of draws the chain is expected to produce.
    pub total_draws: usize,
    /// Number of divergent draws that happened outside of tuning.
    pub divergences: usize,
    /// Whether the most recently recorded draw was a tuning draw. Starts out as `true`.
    pub tuning: bool,
    /// Whether the chain has started. Starts out as `false`.
    pub started: bool,
    /// Number of steps reported for the most recent draw.
    pub latest_num_steps: usize,
    /// Sum of the number of steps reported over all recorded draws.
    pub total_num_steps: usize,
    /// Step size reported for the most recent draw.
    pub step_size: f64,
    /// Accumulated duration of all recorded draws.
    pub runtime: Duration,
    /// Draw indices (the value of `finished_draws` at that time) of the draws counted in
    /// `divergences`.
    pub divergent_draws: Vec<usize>,
}
1021
1022impl ChainProgress {
1023    fn new(total: usize) -> Self {
1024        Self {
1025            finished_draws: 0,
1026            total_draws: total,
1027            divergences: 0,
1028            tuning: true,
1029            started: false,
1030            latest_num_steps: 0,
1031            step_size: 0f64,
1032            total_num_steps: 0,
1033            runtime: Duration::ZERO,
1034            divergent_draws: Vec::new(),
1035        }
1036    }
1037
1038    fn update(&mut self, stats: &Progress, draw_duration: Duration) {
1039        if stats.diverging & !stats.tuning {
1040            self.divergences += 1;
1041            self.divergent_draws.push(self.finished_draws);
1042        }
1043        self.finished_draws += 1;
1044        self.tuning = stats.tuning;
1045
1046        self.latest_num_steps = stats.num_steps as usize;
1047        self.total_num_steps += stats.num_steps as usize;
1048        self.step_size = stats.step_size;
1049        self.runtime += draw_duration;
1050    }
1051}
1052
/// Control messages sent from the controller to an individual chain worker.
#[cfg(feature = "parallel")]
enum ChainCommand {
    /// Keep drawing samples.
    Resume,
    /// Stop drawing and block until the next command arrives.
    Pause,
}
1058
/// Controller-side handle to one chain that samples on the worker pool.
#[cfg(feature = "parallel")]
struct ChainProcess<T>
where
    T: TraceStorage,
{
    /// Sending half of the command channel of the chain. The worker stops sampling once this
    /// sender is dropped.
    stop_marker: Sender<ChainCommand>,
    /// Per-chain storage shared with the worker. It becomes `None` once the storage has been
    /// taken for finalization, which also tells the worker to stop.
    trace: Arc<Mutex<Option<T::ChainStorage>>>,
    /// Progress record shared with the worker, which updates it after every draw.
    progress: Arc<Mutex<ChainProgress>>,
}
1068
#[cfg(feature = "parallel")]
impl<T: TraceStorage> ChainProcess<T> {
    /// Finalizes the storage of all `chains` and then the overall `trace`.
    ///
    /// Taking the per-chain storage out of its mutex also signals a still-running worker to
    /// stop. Chains whose storage was already taken are skipped.
    ///
    /// # Panics
    ///
    /// Panics if a trace mutex is poisoned.
    fn finalize_many(trace: T, chains: Vec<Self>) -> Result<(Option<anyhow::Error>, T::Finalized)> {
        let finalized_chain_traces = chains
            .into_iter()
            .filter_map(|chain| chain.trace.lock().expect("Poisoned lock").take())
            .map(|chain| chain.finalize())
            .collect_vec();
        trace.finalize(finalized_chain_traces)
    }

    /// Returns a copy of the current progress record of the chain.
    ///
    /// # Panics
    ///
    /// Panics if the progress mutex is poisoned.
    fn progress(&self) -> ChainProgress {
        self.progress.lock().expect("Poisoned lock").clone()
    }

    /// Tells the worker to continue sampling.
    ///
    /// # Errors
    ///
    /// Fails if the worker has already finished and dropped its end of the channel.
    fn resume(&self) -> Result<()> {
        self.stop_marker.send(ChainCommand::Resume)?;
        Ok(())
    }

    /// Tells the worker to pause before its next draw.
    ///
    /// # Errors
    ///
    /// Fails if the worker has already finished and dropped its end of the channel.
    fn pause(&self) -> Result<()> {
        self.stop_marker.send(ChainCommand::Pause)?;
        Ok(())
    }

    /// Spawns the sampling task of one chain on `scope` and returns the handle to it.
    ///
    /// The task creates the model density and the chain, searches for a valid initial position
    /// (up to 500 attempts), and then draws `hint_num_tune() + hint_num_draws()` samples, writing
    /// each one to `chain_trace`. Its final `Result` is sent through `results`, so sampling
    /// errors are reported there instead of panicking the worker thread.
    ///
    /// The rng of the chain is seeded with `seed` and uses stream `chain_id + 1`.
    ///
    /// The current implementation always returns `Ok`; failures of the task itself are only
    /// visible through `results`.
    fn start<'model, M: Model, S: Settings>(
        model: &'model M,
        chain_trace: T::ChainStorage,
        chain_id: u64,
        seed: u64,
        settings: &'model S,
        scope: &ScopeFifo<'model>,
        results: Sender<Result<()>>,
    ) -> Result<Self> {
        let (stop_marker_tx, stop_marker_rx) = channel();

        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        rng.set_stream(chain_id + 1);

        let chain_trace = Arc::new(Mutex::new(Some(chain_trace)));
        let progress = Arc::new(Mutex::new(ChainProgress::new(
            settings.hint_num_draws() + settings.hint_num_tune(),
        )));

        let trace_inner = chain_trace.clone();
        let progress_inner = progress.clone();

        scope.spawn_fifo(move |_| {
            let chain_trace = trace_inner;
            let progress = progress_inner;

            let mut sample = move || {
                let logp = model
                    .math(&mut rng)
                    .context("Failed to create model density")?;
                let dim = logp.dim();

                let mut sampler = settings.new_chain(chain_id, logp, &mut rng);

                progress.lock().expect("Poisoned mutex").started = true;

                let mut initval = vec![0f64; dim];
                // TODO maxtries
                // Keep proposing initial positions until the chain accepts one; only the error
                // of the last failed attempt is kept.
                let mut error = None;
                for _ in 0..500 {
                    model
                        .init_position(&mut rng, &mut initval)
                        .context("Failed to generate a new initial position")?;
                    if let Err(err) = sampler.set_position(&initval) {
                        error = Some(err);
                        continue;
                    }
                    error = None;
                    break;
                }

                if let Some(error) = error {
                    return Err(error.context("All initialization points failed"));
                }

                let draws = settings.hint_num_tune() + settings.hint_num_draws();

                let mut msg = stop_marker_rx.try_recv();
                let mut draw = 0;
                // The draw budget is checked before every draw, so a budget of zero produces no
                // samples at all.
                while draw < draws {
                    match msg {
                        // The remote end is dead
                        Err(TryRecvError::Disconnected) => {
                            break;
                        }
                        Err(TryRecvError::Empty) => {}
                        Ok(ChainCommand::Pause) => {
                            // Block until the controller sends the next command.
                            msg = stop_marker_rx.recv().map_err(|e| e.into());
                            continue;
                        }
                        Ok(ChainCommand::Resume) => {}
                    }

                    let now = Instant::now();
                    // Propagate sampling errors through `results` instead of panicking.
                    let (_point, mut draw_data, mut stats, info) = sampler
                        .expanded_draw()
                        .context("Failed to draw a sample")?;

                    let mut guard = chain_trace
                        .lock()
                        .expect("Could not unlock trace lock. Poisoned mutex");

                    let Some(trace_val) = guard.as_mut() else {
                        // The trace was removed by controller thread. We can stop sampling
                        break;
                    };
                    progress
                        .lock()
                        .expect("Poisoned mutex")
                        .update(&info, now.elapsed());

                    let math = sampler.math();
                    let dims = StatsDims::from(math.deref());
                    trace_val.record_sample(
                        settings,
                        stats.get_all(&dims),
                        draw_data.get_all(math.deref()),
                        &info,
                    )?;

                    draw += 1;
                    msg = stop_marker_rx.try_recv();
                }
                Ok(())
            };

            let result = sample();

            // We intentionally ignore errors here, because this means some other
            // chain already failed, and should have reported the error.
            let _ = results.send(result);
            drop(results);
        });

        Ok(Self {
            trace: chain_trace,
            stop_marker: stop_marker_tx,
            progress,
        })
    }

    /// Flushes the storage of the chain, if it has not been taken for finalization yet.
    ///
    /// # Errors
    ///
    /// Fails if the trace mutex is poisoned or if the storage reports a flush error.
    fn flush(&self) -> Result<()> {
        self.trace
            .lock()
            .map_err(|_| anyhow::anyhow!("Could not lock trace mutex"))
            .context("Could not flush trace")?
            .as_mut()
            .map(|v| v.flush())
            .transpose()?;
        Ok(())
    }
}
1228
/// Requests sent from a `Sampler` handle to its controller thread.
#[cfg(feature = "parallel")]
#[derive(Debug)]
enum SamplerCommand {
    /// Pause all chains.
    Pause,
    /// Resume all chains.
    Continue,
    /// Report the current progress of every chain.
    Progress,
    /// Flush the storage of every chain.
    Flush,
    /// Produce an inspection of the trace collected so far.
    Inspect,
}
1238
/// Replies sent from the controller thread back to the `Sampler` handle.
#[cfg(feature = "parallel")]
enum SamplerResponse<T: Send + 'static> {
    /// Acknowledges a pause, continue, or flush command.
    Ok(),
    /// Per-chain progress, in reply to a progress command.
    Progress(Box<[ChainProgress]>),
    /// Result of inspecting the trace, in reply to an inspect command.
    Inspect(T),
}
1245
/// Outcome of `Sampler::wait_timeout`.
#[cfg(feature = "parallel")]
pub enum SamplerWaitResult<F: Send + 'static> {
    /// Sampling finished and the trace was finalized without a reported error.
    Trace(F),
    /// The timeout elapsed first; the sampler is handed back so the caller can keep using it.
    Timeout(Sampler<F>),
    /// Sampling or finalization failed. The finalized trace is included when one was still
    /// produced.
    Err(anyhow::Error, Option<F>),
}
1252
/// Handle to a sampling run that executes on a background controller thread.
#[cfg(feature = "parallel")]
pub struct Sampler<F: Send + 'static> {
    /// Join handle of the controller thread; its result is the finalized trace.
    main_thread: JoinHandle<Result<(Option<anyhow::Error>, F)>>,
    /// Sends commands to the controller thread. Dropping it makes the controller shut down.
    commands: SyncSender<SamplerCommand>,
    /// Receives the reply to each command.
    responses: Receiver<SamplerResponse<(Option<anyhow::Error>, F)>>,
    /// Receives the final result of each chain worker.
    results: Receiver<Result<()>>,
}
1260
/// Periodic progress reporting hook for a `Sampler`.
#[cfg(feature = "parallel")]
pub struct ProgressCallback {
    /// Called with the elapsed time and the current progress of every chain.
    pub callback: Box<dyn FnMut(Duration, Box<[ChainProgress]>) + Send>,
    /// Interval between two periodic invocations of `callback`.
    pub rate: Duration,
}
1266
#[cfg(feature = "parallel")]
impl<F: Send + 'static> Sampler<F> {
    /// Starts sampling `settings.num_chains()` chains of `model` in the background.
    ///
    /// A controller thread is spawned that owns a thread pool with `num_cores + 1` threads, creates
    /// the trace storage from `trace_config`, starts one worker per chain and then serves the
    /// commands sent through the returned handle. If `callback` is given, it is invoked once at
    /// the start, then at its configured `rate`, and once more when the controller shuts down.
    ///
    /// The function itself currently always returns `Ok`; failures while setting up the pool,
    /// the trace or the chains surface when the controller thread is joined (see
    /// [`Sampler::abort`] and [`Sampler::wait_timeout`]).
    pub fn new<M, S, C, T>(
        model: M,
        settings: S,
        trace_config: C,
        num_cores: usize,
        callback: Option<ProgressCallback>,
    ) -> Result<Self>
    where
        S: Settings,
        C: StorageConfig<Storage = T>,
        M: Model,
        T: TraceStorage<Finalized = F>,
    {
        // Rendezvous channels: a command is only handed over once the controller picks it up.
        let (commands_tx, commands_rx) = sync_channel(0);
        let (responses_tx, responses_rx) = sync_channel(0);
        let (results_tx, results_rx) = channel();

        let main_thread = spawn(move || {
            let pool = ThreadPoolBuilder::new()
                .num_threads(num_cores + 1) // One more thread because the controller also uses one
                .thread_name(|i| format!("nutpie-worker-{i}"))
                .build()
                .context("Could not start thread pool")?;

            let settings_ref = &settings;
            let model_ref = &model;
            let mut callback = callback;

            pool.scope_fifo(move |scope| {
                let results = results_tx;
                let mut chains = Vec::with_capacity(settings.num_chains());

                let mut rng = ChaCha8Rng::seed_from_u64(settings.seed());
                rng.set_stream(0);

                let math = model_ref
                    .math(&mut rng)
                    .context("Could not create model density")?;
                let trace = trace_config
                    .new_trace(settings_ref, &math)
                    .context("Could not create trace object")?;
                drop(math);

                for chain_id in 0..settings.num_chains() {
                    let chain_trace_val = trace
                        .initialize_trace_for_chain(chain_id as u64)
                        .context("Failed to create trace object")?;
                    let chain = ChainProcess::start(
                        model_ref,
                        chain_trace_val,
                        chain_id as u64,
                        settings.seed(),
                        settings_ref,
                        scope,
                        results.clone(),
                    );
                    chains.push(chain);
                }
                // Only the workers keep senders, so the results channel disconnects once every
                // chain has reported.
                drop(results);

                let (chains, errors): (Vec<_>, Vec<_>) = chains.into_iter().partition_result();
                if let Some(error) = errors.into_iter().next() {
                    let _ = ChainProcess::finalize_many(trace, chains);
                    return Err(error).context("Could not start chains");
                }

                let mut main_loop = || {
                    let start_time = Instant::now();
                    let mut pause_start = Instant::now();
                    let mut pause_time = Duration::ZERO;

                    let mut progress_rate = Duration::MAX;
                    if let Some(ProgressCallback { callback, rate }) = &mut callback {
                        let progress = chains.iter().map(|chain| chain.progress()).collect_vec();
                        callback(start_time.elapsed(), progress.into());
                        progress_rate = *rate;
                    }
                    let mut last_progress = Instant::now();
                    let mut is_paused = false;

                    loop {
                        // Time left until the next progress report is due. If it is already
                        // overdue, report now and wait a full interval.
                        let timeout = progress_rate.checked_sub(last_progress.elapsed());
                        let timeout = timeout.unwrap_or_else(|| {
                            if let Some(ProgressCallback { callback, .. }) = &mut callback {
                                let progress =
                                    chains.iter().map(|chain| chain.progress()).collect_vec();
                                // Report the elapsed time without the time spent paused.
                                let mut elapsed = start_time.elapsed().saturating_sub(pause_time);
                                if is_paused {
                                    elapsed = elapsed.saturating_sub(pause_start.elapsed());
                                }
                                callback(elapsed, progress.into());
                            }
                            last_progress = Instant::now();
                            progress_rate
                        });

                        // TODO return when all chains are done
                        match commands_rx.recv_timeout(timeout) {
                            Ok(SamplerCommand::Pause) => {
                                for chain in chains.iter() {
                                    // This fails if the thread is done.
                                    // We just want to ignore those threads.
                                    let _ = chain.pause();
                                }
                                if !is_paused {
                                    pause_start = Instant::now();
                                }
                                is_paused = true;
                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
                                    anyhow::anyhow!(
                                        "Could not send pause response to controller thread: {e}"
                                    )
                                })?;
                            }
                            Ok(SamplerCommand::Continue) => {
                                for chain in chains.iter() {
                                    // This fails if the thread is done.
                                    // We just want to ignore those threads.
                                    let _ = chain.resume();
                                }
                                // Only a pause that is actually in effect contributes to the
                                // paused time; a resume without a preceding pause must not.
                                if is_paused {
                                    pause_time += pause_start.elapsed();
                                }
                                is_paused = false;
                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
                                    anyhow::anyhow!(
                                        "Could not send continue response to controller thread: {e}"
                                    )
                                })?;
                            }
                            Ok(SamplerCommand::Progress) => {
                                let progress =
                                    chains.iter().map(|chain| chain.progress()).collect_vec();
                                responses_tx
                                    .send(SamplerResponse::Progress(progress.into()))
                                    .map_err(|e| {
                                        anyhow::anyhow!(
                                            "Could not send progress response to controller thread: {e}"
                                        )
                                    })?;
                            }
                            Ok(SamplerCommand::Inspect) => {
                                let traces = chains
                                    .iter()
                                    .filter_map(|chain| {
                                        chain
                                            .trace
                                            .lock()
                                            .expect("Poisoned lock")
                                            .as_ref()
                                            .map(|v| v.inspect())
                                    })
                                    .collect_vec();
                                let finalized_trace = trace.inspect(traces)?;
                                responses_tx
                                    .send(SamplerResponse::Inspect(finalized_trace))
                                    .map_err(|e| {
                                        anyhow::anyhow!(
                                            "Could not send inspect response to controller thread: {e}"
                                        )
                                    })?;
                            }
                            Ok(SamplerCommand::Flush) => {
                                for chain in chains.iter() {
                                    chain.flush()?;
                                }
                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
                                    anyhow::anyhow!(
                                        "Could not send flush response to controller thread: {e}"
                                    )
                                })?;
                            }
                            Err(RecvTimeoutError::Timeout) => {}
                            Err(RecvTimeoutError::Disconnected) => {
                                // The handle is gone: send a last progress report and shut down.
                                if let Some(ProgressCallback { callback, .. }) = &mut callback {
                                    let progress =
                                        chains.iter().map(|chain| chain.progress()).collect_vec();
                                    let mut elapsed =
                                        start_time.elapsed().saturating_sub(pause_time);
                                    if is_paused {
                                        elapsed = elapsed.saturating_sub(pause_start.elapsed());
                                    }
                                    callback(elapsed, progress.into());
                                }
                                return Ok(());
                            }
                        };
                    }
                };
                let result: Result<()> = main_loop();
                // Run finalization even if something failed
                let output = ChainProcess::finalize_many(trace, chains)?;

                result?;
                Ok(output)
            })
        });

        Ok(Self {
            main_thread,
            commands: commands_tx,
            responses: responses_rx,
            results: results_rx,
        })
    }

    /// Pauses all chains and waits for the controller to acknowledge.
    ///
    /// # Errors
    ///
    /// Fails if the controller thread is no longer running or replies unexpectedly.
    pub fn pause(&mut self) -> Result<()> {
        self.commands
            .send(SamplerCommand::Pause)
            .context("Could not send pause command to controller thread")?;
        let response = self
            .responses
            .recv()
            .context("Could not recieve pause response from controller thread")?;
        let SamplerResponse::Ok() = response else {
            bail!("Got invalid response from sample controller thread");
        };
        Ok(())
    }

    /// Resumes all chains and waits for the controller to acknowledge.
    ///
    /// # Errors
    ///
    /// Fails if the controller thread is no longer running or replies unexpectedly.
    pub fn resume(&mut self) -> Result<()> {
        self.commands.send(SamplerCommand::Continue)?;
        let response = self.responses.recv()?;
        let SamplerResponse::Ok() = response else {
            bail!("Got invalid response from sample controller thread");
        };
        Ok(())
    }

    /// Flushes the storage of every chain and waits for the controller to acknowledge.
    ///
    /// # Errors
    ///
    /// Fails if the controller thread is no longer running or replies unexpectedly.
    pub fn flush(&mut self) -> Result<()> {
        self.commands.send(SamplerCommand::Flush)?;
        let response = self
            .responses
            .recv()
            .context("Could not recieve flush response from controller thread")?;
        let SamplerResponse::Ok() = response else {
            bail!("Got invalid response from sample controller thread");
        };
        Ok(())
    }

    /// Returns an inspection of the trace collected so far without stopping the run.
    ///
    /// # Errors
    ///
    /// Fails if the controller thread is no longer running or replies unexpectedly.
    pub fn inspect(&mut self) -> Result<(Option<anyhow::Error>, F)> {
        self.commands.send(SamplerCommand::Inspect)?;
        let response = self
            .responses
            .recv()
            .context("Could not recieve inspect response from controller thread")?;
        let SamplerResponse::Inspect(trace) = response else {
            bail!("Got invalid response from sample controller thread");
        };
        Ok(trace)
    }

    /// Stops the run and returns the finalized trace.
    ///
    /// Dropping the command channel makes the controller leave its loop and finalize the trace;
    /// this call then blocks until the controller thread has finished.
    ///
    /// # Errors
    ///
    /// Returns the error the controller thread ended with, if any.
    ///
    /// # Panics
    ///
    /// Resumes the panic of the controller thread if it panicked.
    pub fn abort(self) -> Result<(Option<anyhow::Error>, F)> {
        drop(self.commands);
        let result = self.main_thread.join();
        match result {
            Err(payload) => std::panic::resume_unwind(payload),
            Ok(Ok(val)) => Ok(val),
            Ok(Err(err)) => Err(err),
        }
    }

    /// Waits up to `timeout` for all chains to finish.
    ///
    /// Returns [`SamplerWaitResult::Trace`] or [`SamplerWaitResult::Err`] once every chain has
    /// reported or one of them failed, and [`SamplerWaitResult::Timeout`] (handing the sampler
    /// back) if the run is still going when the time is up. The timeout bounds the total wait,
    /// not the wait for each individual chain.
    pub fn wait_timeout(self, timeout: Duration) -> SamplerWaitResult<F> {
        let start = Instant::now();
        let mut remaining = Some(timeout);
        while let Some(time_left) = remaining {
            // Wait only for what is left of the budget.
            match self.results.recv_timeout(time_left) {
                Ok(Ok(_)) => remaining = timeout.checked_sub(start.elapsed()),
                Ok(Err(e)) => return SamplerWaitResult::Err(e, None),
                // All workers have reported and dropped their senders: collect the trace.
                Err(RecvTimeoutError::Disconnected) => match self.abort() {
                    Ok((Some(err), trace)) => return SamplerWaitResult::Err(err, Some(trace)),
                    Ok((None, trace)) => return SamplerWaitResult::Trace(trace),
                    Err(err) => return SamplerWaitResult::Err(err, None),
                },
                Err(RecvTimeoutError::Timeout) => break,
            }
        }
        SamplerWaitResult::Timeout(self)
    }

    /// Returns the current progress of every chain.
    ///
    /// # Errors
    ///
    /// Fails if the controller thread is no longer running or replies unexpectedly.
    pub fn progress(&mut self) -> Result<Box<[ChainProgress]>> {
        self.commands.send(SamplerCommand::Progress)?;
        let response = self.responses.recv()?;
        let SamplerResponse::Progress(progress) = response else {
            bail!("Got invalid response from sample controller thread");
        };
        Ok(progress)
    }
}
1553
#[cfg(test)]
pub mod test_logps {
    use crate::{
        Model,
        math::{CpuLogpFunc, CpuMath},
    };
    use anyhow::Result;
    use rand::Rng;

    /// Test-only [`Model`] that wraps a log-density function evaluated through [`CpuMath`].
    pub struct CpuModel<F> {
        logp: F,
    }

    impl<F> CpuModel<F> {
        /// Wraps `logp` in a model.
        pub fn new(logp: F) -> Self {
            CpuModel { logp }
        }
    }

    impl<F> Model for CpuModel<F>
    where
        F: Send + Sync + 'static,
        for<'a> &'a F: CpuLogpFunc,
    {
        type Math<'model> = CpuMath<&'model F>;

        /// Creates a math backend that borrows the wrapped density; the rng is not used.
        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            let density = &self.logp;
            Ok(CpuMath::new(density))
        }

        /// Always starts at the origin; the rng is not used.
        fn init_position<R: Rng + ?Sized>(
            &self,
            _rng: &mut R,
            position: &mut [f64],
        ) -> Result<()> {
            position.fill(0.);
            Ok(())
        }
    }
}
1591
1592#[cfg(test)]
1593mod tests {
1594    use crate::math::test_logps::NormalLogp;
1595    use crate::{
1596        Chain, math::CpuMath, sample_sequentially, sampler::DiagMclmcSettings,
1597        sampler::DiagNutsSettings, sampler::LowRankMclmcSettings, sampler::LowRankNutsSettings,
1598        sampler::Settings,
1599    };
1600
1601    #[cfg(feature = "zarr")]
1602    use super::test_logps::CpuModel;
1603
1604    use anyhow::Result;
1605    use itertools::Itertools;
1606    use pretty_assertions::assert_eq;
1607    use rand::{SeedableRng, rngs::StdRng};
1608
1609    #[cfg(feature = "zarr")]
1610    use std::{
1611        sync::Arc,
1612        time::{Duration, Instant},
1613    };
1614
1615    #[cfg(feature = "zarr")]
1616    use crate::{Sampler, ZarrConfig};
1617
1618    #[cfg(feature = "zarr")]
1619    use zarrs::storage::store::MemoryStore;
1620
1621    fn assert_settings_smoke<S: Settings>(settings: S) -> Result<()> {
1622        let logp = NormalLogp { dim: 4, mu: 0.1 };
1623        let math = CpuMath::new(&logp);
1624        let mut rng = StdRng::seed_from_u64(42);
1625
1626        let stat_names = settings.stat_names(&math);
1627        let stat_types = settings.stat_types(&math);
1628        assert!(!stat_names.is_empty());
1629        assert_eq!(stat_names.len(), stat_types.len());
1630
1631        let mut chain = settings.new_chain(0, math, &mut rng);
1632        chain.set_position(&vec![0.2; 4])?;
1633        let (_draw, _info) = chain.draw()?;
1634        Ok(())
1635    }
1636
1637    #[test]
1638    fn all_settings_smoke() -> Result<()> {
1639        assert_settings_smoke(DiagNutsSettings {
1640            num_tune: 10,
1641            num_draws: 10,
1642            ..Default::default()
1643        })?;
1644        assert_settings_smoke(LowRankNutsSettings {
1645            num_tune: 10,
1646            num_draws: 10,
1647            ..Default::default()
1648        })?;
1649        assert_settings_smoke(DiagMclmcSettings {
1650            num_tune: 10,
1651            num_draws: 10,
1652            ..Default::default()
1653        })?;
1654        assert_settings_smoke(LowRankMclmcSettings {
1655            num_tune: 10,
1656            num_draws: 10,
1657            ..Default::default()
1658        })?;
1659        Ok(())
1660    }
1661
1662    #[test]
1663    fn sample_chain() -> Result<()> {
1664        let logp = NormalLogp { dim: 10, mu: 0.1 };
1665        let math = CpuMath::new(&logp);
1666        let settings = DiagNutsSettings {
1667            num_tune: 100,
1668            num_draws: 100,
1669            ..Default::default()
1670        };
1671        let start = vec![0.2; 10];
1672
1673        let mut rng = StdRng::seed_from_u64(42);
1674
1675        let mut chain = settings.new_chain(0, math, &mut rng);
1676
1677        let (_draw, info) = chain.draw()?;
1678        assert!(info.tuning);
1679        assert_eq!(info.draw, 0);
1680
1681        let math = CpuMath::new(&logp);
1682        let chain = sample_sequentially(math, settings, &start, 200, 1, &mut rng).unwrap();
1683        let mut draws = chain.collect_vec();
1684        assert_eq!(draws.len(), 200);
1685
1686        let draw0 = draws.remove(100).unwrap();
1687        let (vals, stats) = draw0;
1688        assert_eq!(vals.len(), 10);
1689        assert_eq!(stats.chain, 1);
1690        assert_eq!(stats.draw, 100);
1691        Ok(())
1692    }
1693
1694    #[cfg(feature = "zarr")]
1695    #[test]
1696    fn sample_parallel() -> Result<()> {
1697        let logp = NormalLogp { dim: 100, mu: 0.1 };
1698        let settings = DiagNutsSettings {
1699            num_tune: 100,
1700            num_draws: 100,
1701            seed: 10,
1702            ..Default::default()
1703        };
1704
1705        let model = CpuModel::new(logp.clone());
1706        let store = MemoryStore::new();
1707
1708        let zarr_config = ZarrConfig::new(Arc::new(store));
1709        let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
1710        sampler.pause()?;
1711        sampler.pause()?;
1712        // TODO flush trace
1713        sampler.resume()?;
1714        let (ok, _) = sampler.abort()?;
1715        if let Some(err) = ok {
1716            Err(err)?;
1717        }
1718
1719        let store = MemoryStore::new();
1720        let zarr_config = ZarrConfig::new(Arc::new(store));
1721        let model = CpuModel::new(logp.clone());
1722        let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
1723        sampler.pause()?;
1724        if let (Some(err), _) = sampler.abort()? {
1725            Err(err)?;
1726        }
1727
1728        let store = MemoryStore::new();
1729        let zarr_config = ZarrConfig::new(Arc::new(store));
1730        let model = CpuModel::new(logp.clone());
1731        let start = Instant::now();
1732        let sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
1733
1734        let mut sampler = match sampler.wait_timeout(Duration::from_nanos(100)) {
1735            super::SamplerWaitResult::Trace(_) => {
1736                dbg!(start.elapsed());
1737                panic!("finished");
1738            }
1739            super::SamplerWaitResult::Timeout(sampler) => sampler,
1740            super::SamplerWaitResult::Err(_, _) => {
1741                panic!("error")
1742            }
1743        };
1744
1745        for _ in 0..30 {
1746            sampler.progress()?;
1747        }
1748
1749        match sampler.wait_timeout(Duration::from_secs(1)) {
1750            super::SamplerWaitResult::Trace(_) => {
1751                dbg!(start.elapsed());
1752            }
1753            super::SamplerWaitResult::Timeout(_) => {
1754                panic!("timeout")
1755            }
1756            super::SamplerWaitResult::Err(err, _) => Err(err)?,
1757        };
1758
1759        Ok(())
1760    }
1761
1762    #[test]
1763    fn sample_seq() {
1764        let logp = NormalLogp { dim: 10, mu: 0.1 };
1765        let math = CpuMath::new(&logp);
1766        let settings = DiagNutsSettings {
1767            num_tune: 100,
1768            num_draws: 100,
1769            ..Default::default()
1770        };
1771        let start = vec![0.2; 10];
1772
1773        let mut rng = StdRng::seed_from_u64(42);
1774
1775        let chain = sample_sequentially(math, settings, &start, 200, 1, &mut rng).unwrap();
1776        let mut draws = chain.collect_vec();
1777        assert_eq!(draws.len(), 200);
1778
1779        let draw0 = draws.remove(100).unwrap();
1780        let (vals, stats) = draw0;
1781        assert_eq!(vals.len(), 10);
1782        assert_eq!(stats.chain, 1);
1783        assert_eq!(stats.draw, 100);
1784    }
1785}