Skip to main content

nuts_rs/
sampler.rs

1//! High-level sampler entry points: `Settings` presets, the parallel `Sampler`,
2//! and `sample_sequentially` for running one or many chains.
3
4use anyhow::Result;
5use nuts_storable::{HasDims, Storable, Value};
6use rand::{Rng, SeedableRng, rngs::ChaCha8Rng};
7use serde::{Deserialize, Serialize, de::DeserializeOwned};
8use std::{collections::HashMap, fmt::Debug, time::Duration};
9
10#[cfg(feature = "parallel")]
11use anyhow::{Context, bail};
12#[cfg(feature = "parallel")]
13use itertools::Itertools;
14#[cfg(feature = "parallel")]
15use std::ops::Deref;
16
17#[cfg(feature = "parallel")]
18use rayon::{ScopeFifo, ThreadPoolBuilder};
19#[cfg(feature = "parallel")]
20use std::{
21    sync::{
22        Arc, Mutex,
23        mpsc::{
24            Receiver, RecvTimeoutError, Sender, SyncSender, TryRecvError, channel, sync_channel,
25        },
26    },
27    thread::{JoinHandle, spawn},
28    time::Instant,
29};
30
31use crate::{
32    DiagAdaptExpSettings, Math, StepSizeAdaptMethod,
33    adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy, GlobalStrategyStatsOptions},
34    chain::{AdaptStrategy, Chain, NutsChain, StatOptions},
35    dynamics::{KineticEnergyKind, TransformedHamiltonian, TransformedPointStatsOptions},
36    external_adapt_strategy::{ExternalTransformAdaptation, FlowSettings},
37    mclmc::MclmcTrajectoryKind,
38    nuts::NutsOptions,
39    sampler_stats::{SamplerStats, StatsDims},
40    transform::{
41        DiagAdaptStrategy, DiagMassMatrix, ExternalTransformation, LowRankMassMatrix,
42        LowRankMassMatrixStrategy, LowRankSettings,
43    },
44};
45
46#[cfg(feature = "parallel")]
47use crate::{model::Model, storage::{ChainStorage, StorageConfig, TraceStorage}};
48
49/// All sampler configurations implement this trait
50pub trait Settings:
51    private::Sealed + Clone + Copy + Default + Sync + Send + Serialize + DeserializeOwned + 'static
52{
53    type Chain<M: Math>: Chain<M>;
54
55    fn new_chain<M: Math, R: Rng + ?Sized>(
56        &self,
57        chain: u64,
58        math: M,
59        rng: &mut R,
60    ) -> Self::Chain<M>;
61
62    fn hint_num_tune(&self) -> usize;
63    fn hint_num_draws(&self) -> usize;
64    fn num_chains(&self) -> usize;
65    fn seed(&self) -> u64;
66    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions;
67    fn sampler_name(&self) -> &'static str;
68    fn adaptation_name(&self) -> &'static str;
69
70    fn stat_names<M: Math>(&self, math: &M) -> Vec<String> {
71        let dims = StatsDims::from(math);
72        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::names(&dims)
73            .into_iter()
74            .map(String::from)
75            .collect()
76    }
77
78    fn data_names<M: Math>(&self, math: &M) -> Vec<String> {
79        <M::ExpandedVector as Storable<_>>::names(math)
80            .into_iter()
81            .map(String::from)
82            .collect()
83    }
84
85    fn stat_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
86        self.stat_names(math)
87            .into_iter()
88            .map(|name| (name.clone(), self.stat_type::<M>(math, &name)))
89            .collect()
90    }
91
92    fn stat_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
93        let dims = StatsDims::from(math);
94        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::item_type(&dims, name)
95    }
96
97    fn data_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
98        self.data_names(math)
99            .into_iter()
100            .map(|name| (name.clone(), self.data_type(math, &name)))
101            .collect()
102    }
103    fn data_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
104        <M::ExpandedVector as Storable<_>>::item_type(math, name)
105    }
106
107    fn stat_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
108        self.stat_names(math)
109            .into_iter()
110            .map(|name| (name.clone(), self.stat_dims::<M>(math, &name)))
111            .collect()
112    }
113
114    fn stat_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
115        let dims = StatsDims::from(math);
116        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::dims(&dims, name)
117            .into_iter()
118            .map(String::from)
119            .collect()
120    }
121
122    fn stat_dim_sizes<M: Math>(&self, math: &M) -> HashMap<String, u64> {
123        let dims = StatsDims::from(math);
124        dims.dim_sizes()
125    }
126
127    fn data_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
128        self.data_names(math)
129            .into_iter()
130            .map(|name| (name.clone(), self.data_dims(math, &name)))
131            .collect()
132    }
133
134    fn data_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
135        <M::ExpandedVector as Storable<_>>::dims(math, name)
136            .into_iter()
137            .map(String::from)
138            .collect()
139    }
140
141    fn stat_coords<M: Math>(&self, math: &M) -> HashMap<String, Value> {
142        let dims = StatsDims::from(math);
143        dims.coords()
144    }
145
146    fn stat_event_dims<M: Math>(&self, math: &M) -> Vec<(String, Option<String>)> {
147        let dims = StatsDims::from(math);
148        self.stat_names(math)
149            .into_iter()
150            .map(|name| {
151                let event_dim =
152                    <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::event_dim(
153                        &dims, &name,
154                    )
155                    .map(String::from);
156                (name, event_dim)
157            })
158            .collect()
159    }
160}
161
/// Per-chain progress record.
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot build it
/// with a struct literal or destructure it exhaustively.
///
/// NOTE(review): the code that fills this struct is not visible in this part
/// of the file; the field descriptions below are inferred from the field
/// names and should be confirmed against the producer.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Progress {
    /// Draw counter (presumably the index of the most recent draw).
    pub draw: u64,
    /// Index of the chain this record belongs to.
    pub chain: u64,
    /// Divergence flag (presumably for the most recent draw).
    pub diverging: bool,
    /// Tuning flag (presumably whether the chain is still in warmup).
    pub tuning: bool,
    /// Step size (presumably the one currently used by the integrator).
    pub step_size: f64,
    /// Step counter (presumably the number of integrator steps of the draw).
    pub num_steps: u64,
}
172
173mod private {
174    use super::{
175        DiagMclmcSettings, DiagNutsSettings, FlowMclmcSettings, FlowNutsSettings,
176        LowRankMclmcSettings, LowRankNutsSettings,
177    };
178
179    pub trait Sealed {}
180
181    impl Sealed for DiagNutsSettings {}
182
183    impl Sealed for LowRankNutsSettings {}
184
185    impl Sealed for FlowNutsSettings {}
186
187    impl Sealed for DiagMclmcSettings {}
188
189    impl Sealed for LowRankMclmcSettings {}
190
191    impl Sealed for FlowMclmcSettings {}
192}
193
/// Settings for the NUTS sampler
///
/// The type parameter `A` holds the geometry adaptation options. Use the
/// aliases `DiagNutsSettings`, `LowRankNutsSettings` and `FlowNutsSettings`
/// for the concrete configurations.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct NutsSettings<A: Debug + Copy + Default + Serialize> {
    /// The number of tuning steps, where we fit the step size and geometry.
    pub num_tune: u64,
    /// The number of draws after tuning
    pub num_draws: u64,
    /// The maximum tree depth during sampling. The number of leapfrog steps
    /// is smaller than 2 ^ maxdepth.
    pub maxdepth: u64,
    /// The minimum tree depth during sampling. The number of leapfrog steps
    /// is larger than 2 ^ mindepth.
    pub mindepth: u64,
    /// Store the gradient in the SampleStats
    pub store_gradient: bool,
    /// Store each unconstrained parameter vector in the sampler stats
    pub store_unconstrained: bool,
    /// Store the transformed gradient and value in the sampler stats
    pub store_transformed: bool,
    /// If the energy error is larger than this threshold we treat the leapfrog
    /// step as a divergence.
    pub max_energy_error: f64,
    /// Store detailed information about each divergence in the sampler stats
    pub store_divergences: bool,
    /// Settings for geometry adaptation.
    pub adapt_options: A,
    /// Copied unchanged into `NutsOptions::check_turning` when a chain is
    /// built. The presets in this file set it to `true`.
    /// NOTE(review): presumably toggles the turning check of the trajectory —
    /// confirm against `NutsOptions`.
    pub check_turning: bool,
    /// Copied unchanged into `NutsOptions::target_integration_time` when a
    /// chain is built. The presets in this file set it to `None`.
    pub target_integration_time: Option<f64>,
    /// Selects the kinetic-energy form and the corresponding integrator.
    ///
    /// - [`KineticEnergyKind::Euclidean`]: standard leapfrog (default for most settings).
    /// - [`KineticEnergyKind::ExactNormal`]: geodesic leapfrog exact for a standard-normal
    ///   potential.
    /// - [`KineticEnergyKind::Microcanonical`]: isokinetic ESH-dynamics leapfrog (microcanonical
    ///   HMC); momentum is constrained to the unit sphere.
    pub trajectory_kind: KineticEnergyKind,
    /// Number of chains, as reported by `Settings::num_chains`.
    pub num_chains: usize,
    /// Seed, as reported by `Settings::seed`.
    pub seed: u64,
    /// Number of extra doublings to perform after reaching maxdepth. This can
    /// be used to increase the effective sample size at the cost of more
    /// expensive sampling.
    pub extra_doublings: u64,
}
237
/// NUTS settings with diagonal mass matrix adaptation.
///
/// The `Default` preset uses 400 tuning draws and 6 chains.
pub type DiagNutsSettings = NutsSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;
/// Backwards-compatible alias for [`DiagNutsSettings`].
#[deprecated(since = "0.0.0", note = "Use DiagNutsSettings instead")]
pub type DiagGradNutsSettings = DiagNutsSettings;
/// NUTS settings with low-rank mass matrix adaptation.
///
/// The `Default` preset uses 800 tuning draws and 6 chains.
pub type LowRankNutsSettings = NutsSettings<EuclideanAdaptOptions<LowRankSettings>>;
/// NUTS settings that use [`FlowSettings`] as adaptation options.
///
/// The `Default` preset uses 1500 tuning draws, a single chain and a
/// `max_energy_error` of 20.
pub type FlowNutsSettings = NutsSettings<FlowSettings>;
/// Backwards-compatible alias for [`FlowNutsSettings`].
#[deprecated(since = "0.0.0", note = "Use FlowNutsSettings instead")]
pub type TransformedNutsSettings = FlowNutsSettings;
247
/// Settings for the unadjusted Microcanonical Langevin Monte Carlo (MCLMC) sampler.
///
/// > ⚠️ **Experimental — use with caution**: The MCLMC sampler and all of its
/// > variants are highly experimental. They have not been thoroughly validated
/// > and may **not return correct posteriors**. The API, defaults, and
/// > adaptation behaviour are all subject to breaking changes at any time.
/// > Do not use these samplers in production or for results you rely on.
///
/// Step size `ε` and momentum decoherence length `L` are **constants** — no
/// adaptation of those is performed yet. The geometry is adapted during
/// warmup using the sampler-specific adaptation strategy, while the step size
/// remains fixed.
///
/// Use the type aliases [`DiagMclmcSettings`], [`LowRankMclmcSettings`], and
/// [`FlowMclmcSettings`] for concrete configurations.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct MclmcSettings<A: Debug + Copy + Default + Serialize> {
    /// Step size ε for the ESH leapfrog integrator.
    pub step_size: f64,
    /// Momentum decoherence length L (controls partial momentum refresh rate).
    /// Set to `f64::INFINITY` to disable momentum refresh entirely.
    pub momentum_decoherence_length: f64,
    /// Number of warmup draws.
    pub num_tune: u64,
    /// Number of sampling draws after warmup.
    pub num_draws: u64,
    /// Number of parallel chains.
    pub num_chains: usize,
    /// RNG seed.
    pub seed: u64,
    /// Maximum energy error before a step is flagged as a divergence.
    pub max_energy_error: f64,
    /// Store each unconstrained parameter vector in the sampler stats.
    pub store_unconstrained: bool,
    /// Store the gradient in the sampler stats.
    pub store_gradient: bool,
    /// Store the transformed gradient and value in the sampler stats
    pub store_transformed: bool,
    /// Store detailed information about each divergence in the sampler stats
    pub store_divergences: bool,
    /// Geometry adaptation options (step-size fields are ignored for Euclidean settings).
    pub adapt_options: A,
    /// Number of leapfrog steps per draw as a fraction of `L / ε`.
    ///
    /// The number of leapfrog steps between collector calls is:
    /// `round(subsample_frequency * L / ε).max(1)`
    ///
    /// - `1.0` (default) — one sample per full trajectory (at the final step).
    /// - `0.0` — every leapfrog step.
    /// - Values in between space samples as a fraction of the decoherence
    ///   length, so the interval scales naturally when `L` or `ε` changes.
    pub subsample_frequency: f64,
    /// When `true`, use the tree-structured step size retry on divergence:
    /// halve the step size factor and try 2 steps before doubling back.
    /// `log_weight` will include `log(step_size)` to correct for the varying
    /// sampling density. When `false`, divergences are recorded immediately
    /// without any retry and `log_weight = -energy_change`.
    pub dynamic_step_size: bool,
    /// Selects which leapfrog integrator and partial-momentum-refresh style
    /// to use.  See [`MclmcTrajectoryKind`] for the available options.
    /// [`MclmcTrajectoryKind::Microcanonical`] is the original MCLMC.
    /// Default in the presets of this file:
    /// [`MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical`].
    pub trajectory_kind: MclmcTrajectoryKind,
    /// Fraction of `num_tune` draws at which the trajectory is switched from
    /// Euclidean to Microcanonical when
    /// `trajectory_kind == MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical`.
    /// Ignored for other trajectory kinds.  Default: `0.3`.
    pub trajectory_switch_fraction: f64,
}
316
/// MCLMC settings with a diagonal mass matrix adaptation.
///
/// The `Default` preset uses 400 tuning draws and 6 chains.
///
/// > ⚠️ **Experimental — use with caution**: Highly experimental. Correctness
/// > of the returned posteriors has not been verified. May change at any time.
pub type DiagMclmcSettings = MclmcSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;
/// MCLMC settings with a low-rank mass matrix adaptation.
///
/// The `Default` preset uses 800 tuning draws and 6 chains.
///
/// > ⚠️ **Experimental — use with caution**: Highly experimental. Correctness
/// > of the returned posteriors has not been verified. May change at any time.
pub type LowRankMclmcSettings = MclmcSettings<EuclideanAdaptOptions<LowRankSettings>>;
/// MCLMC settings with a learned flow transformation.
///
/// The `Default` preset uses 1500 tuning draws, a single chain and a
/// `max_energy_error` of 20.
///
/// > ⚠️ **Experimental — use with caution**: Highly experimental. Correctness
/// > of the returned posteriors has not been verified. May change at any time.
pub type FlowMclmcSettings = MclmcSettings<FlowSettings>;
/// Backwards-compatible alias for [`FlowMclmcSettings`].
#[deprecated(since = "0.0.0", note = "Use FlowMclmcSettings instead")]
pub type TransformedMclmcSettings = FlowMclmcSettings;
335
/// Convert a `u64` setting into a `usize`.
///
/// `field` is the name of the setting and is only used in the panic message.
///
/// # Panics
///
/// Panics if `value` cannot be represented as a `usize` on this platform.
fn usize_hint(value: u64, field: &str) -> usize {
    match usize::try_from(value) {
        Ok(hint) => hint,
        Err(_) => panic!("{field} must be smaller than usize::MAX"),
    }
}
341
342fn default_mclmc_settings<A: Debug + Copy + Default + Serialize>(
343    adapt_options: A,
344    num_tune: u64,
345    num_chains: usize,
346    max_energy_error: f64,
347) -> MclmcSettings<A> {
348    MclmcSettings {
349        step_size: 0.5,
350        momentum_decoherence_length: 3.0,
351        num_tune,
352        num_draws: 1000,
353        num_chains,
354        seed: 0,
355        max_energy_error,
356        store_unconstrained: false,
357        store_gradient: false,
358        store_divergences: false,
359        store_transformed: false,
360        adapt_options,
361        subsample_frequency: 1.0,
362        dynamic_step_size: true,
363        trajectory_kind: MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical,
364        trajectory_switch_fraction: 0.3,
365    }
366}
367
368impl Default for DiagMclmcSettings {
369    fn default() -> Self {
370        let mut adapt_options = EuclideanAdaptOptions::default();
371        adapt_options.step_size_settings.adapt_options.method = StepSizeAdaptMethod::Fixed(0.5);
372        default_mclmc_settings(adapt_options, 400, 6, 1000.0)
373    }
374}
375
376impl Default for LowRankMclmcSettings {
377    fn default() -> Self {
378        let mut adapt_options = EuclideanAdaptOptions::default();
379        adapt_options.early_mass_matrix_switch_freq = 20;
380        adapt_options.step_size_settings.adapt_options.method = StepSizeAdaptMethod::Fixed(0.5);
381        default_mclmc_settings(adapt_options, 800, 6, 1000.0)
382    }
383}
384
385impl Default for FlowMclmcSettings {
386    fn default() -> Self {
387        default_mclmc_settings(FlowSettings::default(), 1500, 1, 20.0)
388    }
389}
390
/// Chain type produced by `DiagMclmcSettings`: an `MclmcChain` driven by a
/// `ChaCha8Rng`, adapted by `GlobalStrategy` over `DiagAdaptStrategy`, with a
/// `DiagMassMatrix` as the last type parameter.
type DiagMclmcChain<M> = crate::mclmc::MclmcChain<
    M,
    ChaCha8Rng,
    GlobalStrategy<M, DiagAdaptStrategy<M>>,
    DiagMassMatrix<M>,
>;
/// Chain type produced by `LowRankMclmcSettings`: an `MclmcChain` driven by a
/// `ChaCha8Rng`, adapted by `GlobalStrategy` over `LowRankMassMatrixStrategy`,
/// with a `LowRankMassMatrix` as the last type parameter.
type LowRankMclmcChain<M> = crate::mclmc::MclmcChain<
    M,
    ChaCha8Rng,
    GlobalStrategy<M, LowRankMassMatrixStrategy>,
    LowRankMassMatrix<M>,
>;
403
404impl Settings for DiagMclmcSettings {
405    type Chain<M: Math> = DiagMclmcChain<M>;
406
407    fn new_chain<M: Math, R: Rng + ?Sized>(
408        &self,
409        chain: u64,
410        mut math: M,
411        rng: &mut R,
412    ) -> Self::Chain<M> {
413        use crate::dynamics::KineticEnergyKind;
414        use crate::mclmc::MclmcChain;
415        use crate::stepsize::StepSizeAdaptMethod;
416
417        let num_tune = self.num_tune;
418        let mut adapt_options = self.adapt_options;
419        adapt_options.step_size_settings.adapt_options.method =
420            StepSizeAdaptMethod::Fixed(self.step_size);
421        let strategy = GlobalStrategy::<M, DiagAdaptStrategy<M>>::new(
422            &mut math,
423            adapt_options,
424            num_tune,
425            chain,
426        );
427        let mass_matrix = DiagMassMatrix::new(
428            &mut math,
429            self.adapt_options.mass_matrix_options.store_mass_matrix,
430        );
431        let initial_kind = match self.trajectory_kind {
432            MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
433            MclmcTrajectoryKind::Euclidean
434            | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
435        };
436        let mut hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, initial_kind);
437        hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
438        let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
439        let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
440        let stats_options = self.stats_options::<M>();
441        MclmcChain::new(
442            math,
443            hamiltonian,
444            strategy,
445            rng,
446            chain,
447            self.subsample_frequency,
448            self.dynamic_step_size,
449            self.trajectory_kind,
450            switch_draw,
451            self.max_energy_error,
452            stats_options,
453        )
454    }
455
456    fn hint_num_tune(&self) -> usize {
457        usize_hint(self.num_tune, "num_tune")
458    }
459
460    fn hint_num_draws(&self) -> usize {
461        usize_hint(self.num_draws, "num_draws")
462    }
463
464    fn num_chains(&self) -> usize {
465        self.num_chains
466    }
467
468    fn seed(&self) -> u64 {
469        self.seed
470    }
471
472    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
473        StatOptions {
474            adapt: GlobalStrategyStatsOptions {
475                step_size: (),
476                mass_matrix: (),
477            },
478            hamiltonian: -1,
479            point: {
480                let store_gradient = self.store_gradient;
481                let store_unconstrained = self.store_unconstrained;
482                let store_transformed = self.store_transformed;
483                TransformedPointStatsOptions {
484                    store_gradient,
485                    store_unconstrained,
486                    store_transformed,
487                }
488            },
489            divergence: crate::dynamics::DivergenceStatsOptions {
490                store_divergences: self.store_divergences,
491            },
492        }
493    }
494
495    fn sampler_name(&self) -> &'static str {
496        "mclmc"
497    }
498
499    fn adaptation_name(&self) -> &'static str {
500        "diagonal"
501    }
502}
503
504fn default_nuts_settings<A: Debug + Copy + Default + Serialize>(
505    adapt_options: A,
506    num_tune: u64,
507    num_chains: usize,
508    max_energy_error: f64,
509) -> NutsSettings<A> {
510    NutsSettings {
511        num_tune,
512        num_draws: 1000,
513        maxdepth: 10,
514        mindepth: 0,
515        max_energy_error,
516        store_gradient: false,
517        store_unconstrained: false,
518        store_transformed: false,
519        store_divergences: false,
520        adapt_options,
521        check_turning: true,
522        seed: 0,
523        num_chains,
524        target_integration_time: None,
525        trajectory_kind: KineticEnergyKind::Euclidean,
526        extra_doublings: 0,
527    }
528}
529
530impl Settings for LowRankMclmcSettings {
531    type Chain<M: Math> = LowRankMclmcChain<M>;
532
533    fn new_chain<M: Math, R: Rng + ?Sized>(
534        &self,
535        chain: u64,
536        mut math: M,
537        rng: &mut R,
538    ) -> Self::Chain<M> {
539        use crate::dynamics::KineticEnergyKind;
540        use crate::mclmc::MclmcChain;
541        use crate::stepsize::StepSizeAdaptMethod;
542
543        let num_tune = self.num_tune;
544        let mut adapt_options = self.adapt_options;
545        adapt_options.step_size_settings.adapt_options.method =
546            StepSizeAdaptMethod::Fixed(self.step_size);
547        let strategy = GlobalStrategy::<M, LowRankMassMatrixStrategy>::new(
548            &mut math,
549            adapt_options,
550            num_tune,
551            chain,
552        );
553        let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
554        let initial_kind = match self.trajectory_kind {
555            MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
556            MclmcTrajectoryKind::Euclidean
557            | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
558        };
559        let mut hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, initial_kind);
560        hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
561        let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
562        let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
563        let stats_options = self.stats_options::<M>();
564        MclmcChain::new(
565            math,
566            hamiltonian,
567            strategy,
568            rng,
569            chain,
570            self.subsample_frequency,
571            self.dynamic_step_size,
572            self.trajectory_kind,
573            switch_draw,
574            self.max_energy_error,
575            stats_options,
576        )
577    }
578
579    fn hint_num_tune(&self) -> usize {
580        usize_hint(self.num_tune, "num_tune")
581    }
582
583    fn hint_num_draws(&self) -> usize {
584        usize_hint(self.num_draws, "num_draws")
585    }
586
587    fn num_chains(&self) -> usize {
588        self.num_chains
589    }
590
591    fn seed(&self) -> u64 {
592        self.seed
593    }
594
595    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
596        StatOptions {
597            adapt: GlobalStrategyStatsOptions {
598                step_size: (),
599                mass_matrix: (),
600            },
601            hamiltonian: -1,
602            point: {
603                let store_gradient = self.store_gradient;
604                let store_unconstrained = self.store_unconstrained;
605                let store_transformed = self.store_transformed;
606                TransformedPointStatsOptions {
607                    store_gradient,
608                    store_unconstrained,
609                    store_transformed,
610                }
611            },
612            divergence: crate::dynamics::DivergenceStatsOptions {
613                store_divergences: self.store_divergences,
614            },
615        }
616    }
617
618    fn sampler_name(&self) -> &'static str {
619        "mclmc"
620    }
621
622    fn adaptation_name(&self) -> &'static str {
623        "low_rank"
624    }
625}
626
627impl Default for DiagNutsSettings {
628    fn default() -> Self {
629        default_nuts_settings(EuclideanAdaptOptions::default(), 400, 6, 1000.0)
630    }
631}
632
633impl Default for LowRankNutsSettings {
634    fn default() -> Self {
635        let mut vals = default_nuts_settings(EuclideanAdaptOptions::default(), 800, 6, 1000.0);
636        vals.adapt_options.mass_matrix_update_freq = 20;
637        vals
638    }
639}
640
641impl Default for FlowNutsSettings {
642    fn default() -> Self {
643        default_nuts_settings(FlowSettings::default(), 1500, 1, 20.0)
644    }
645}
646
/// Chain type produced by `DiagNutsSettings`: a `NutsChain` driven by a
/// `ChaCha8Rng` and adapted by `GlobalStrategy` over `DiagAdaptStrategy`.
type DiagNutsChain<M> = NutsChain<M, ChaCha8Rng, GlobalStrategy<M, DiagAdaptStrategy<M>>>;
/// Chain type produced by `LowRankNutsSettings`: a `NutsChain` driven by a
/// `ChaCha8Rng` and adapted by `GlobalStrategy` over `LowRankMassMatrixStrategy`.
type LowRankNutsChain<M> = NutsChain<M, ChaCha8Rng, GlobalStrategy<M, LowRankMassMatrixStrategy>>;
649
650fn nuts_options(settings: &NutsSettings<impl Debug + Copy + Default + Serialize>) -> NutsOptions {
651    NutsOptions {
652        maxdepth: settings.maxdepth,
653        mindepth: settings.mindepth,
654        store_divergences: settings.store_divergences,
655        check_turning: settings.check_turning,
656        target_integration_time: settings.target_integration_time,
657        extra_doublings: settings.extra_doublings,
658        max_energy_error: settings.max_energy_error,
659    }
660}
661
662impl Settings for LowRankNutsSettings {
663    type Chain<M: Math> = LowRankNutsChain<M>;
664
665    fn new_chain<M: Math, R: Rng + ?Sized>(
666        &self,
667        chain: u64,
668        mut math: M,
669        mut rng: &mut R,
670    ) -> Self::Chain<M> {
671        let num_tune = self.num_tune;
672        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
673        let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
674        let hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, self.trajectory_kind);
675
676        let options = nuts_options(self);
677
678        let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
679
680        NutsChain::new(
681            math,
682            hamiltonian,
683            strategy,
684            options,
685            rng,
686            chain,
687            self.stats_options(),
688        )
689    }
690
691    fn hint_num_tune(&self) -> usize {
692        usize_hint(self.num_tune, "num_tune")
693    }
694
695    fn hint_num_draws(&self) -> usize {
696        usize_hint(self.num_draws, "num_draws")
697    }
698
699    fn num_chains(&self) -> usize {
700        self.num_chains
701    }
702
703    fn seed(&self) -> u64 {
704        self.seed
705    }
706
707    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
708        StatOptions {
709            adapt: GlobalStrategyStatsOptions {
710                mass_matrix: (),
711                step_size: (),
712            },
713            hamiltonian: -1,
714            point: {
715                let store_gradient = self.store_gradient;
716                let store_unconstrained = self.store_unconstrained;
717                let store_transformed = self.store_transformed;
718                TransformedPointStatsOptions {
719                    store_gradient,
720                    store_unconstrained,
721                    store_transformed,
722                }
723            },
724            divergence: crate::dynamics::DivergenceStatsOptions {
725                store_divergences: self.store_divergences,
726            },
727        }
728    }
729
730    fn sampler_name(&self) -> &'static str {
731        "nuts"
732    }
733
734    fn adaptation_name(&self) -> &'static str {
735        "low_rank"
736    }
737}
738
739impl Settings for DiagNutsSettings {
740    type Chain<M: Math> = DiagNutsChain<M>;
741
742    fn new_chain<M: Math, R: Rng + ?Sized>(
743        &self,
744        chain: u64,
745        mut math: M,
746        mut rng: &mut R,
747    ) -> Self::Chain<M> {
748        let num_tune = self.num_tune;
749        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
750        let mass_matrix = DiagMassMatrix::new(
751            &mut math,
752            self.adapt_options.mass_matrix_options.store_mass_matrix,
753        );
754        let potential = TransformedHamiltonian::new(&mut math, mass_matrix, self.trajectory_kind);
755
756        let options = nuts_options(self);
757
758        let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
759
760        NutsChain::new(
761            math,
762            potential,
763            strategy,
764            options,
765            rng,
766            chain,
767            self.stats_options(),
768        )
769    }
770
771    fn hint_num_tune(&self) -> usize {
772        usize_hint(self.num_tune, "num_tune")
773    }
774
775    fn hint_num_draws(&self) -> usize {
776        usize_hint(self.num_draws, "num_draws")
777    }
778
779    fn num_chains(&self) -> usize {
780        self.num_chains
781    }
782
783    fn seed(&self) -> u64 {
784        self.seed
785    }
786
787    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
788        StatOptions {
789            adapt: GlobalStrategyStatsOptions {
790                mass_matrix: (),
791                step_size: (),
792            },
793            hamiltonian: -1,
794            point: {
795                let store_gradient = self.store_gradient;
796                let store_unconstrained = self.store_unconstrained;
797                let store_transformed = self.store_transformed;
798                TransformedPointStatsOptions {
799                    store_gradient,
800                    store_unconstrained,
801                    store_transformed,
802                }
803            },
804            divergence: crate::dynamics::DivergenceStatsOptions {
805                store_divergences: self.store_divergences,
806            },
807        }
808    }
809
810    fn sampler_name(&self) -> &'static str {
811        "nuts"
812    }
813
814    fn adaptation_name(&self) -> &'static str {
815        "diagonal"
816    }
817}
818
819impl Settings for FlowNutsSettings {
820    type Chain<M: Math> = NutsChain<M, ChaCha8Rng, ExternalTransformAdaptation>;
821
822    fn new_chain<M: Math, R: Rng + ?Sized>(
823        &self,
824        chain: u64,
825        mut math: M,
826        mut rng: &mut R,
827    ) -> Self::Chain<M> {
828        let num_tune = self.num_tune;
829
830        let strategy =
831            ExternalTransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
832        let params = math
833            .new_transformation(rng, math.dim(), chain)
834            .expect("Failed to create external transformation");
835        let transform = ExternalTransformation::new(params);
836        let hamiltonian = TransformedHamiltonian::new(&mut math, transform, self.trajectory_kind);
837
838        let options = nuts_options(self);
839
840        let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
841        NutsChain::new(
842            math,
843            hamiltonian,
844            strategy,
845            options,
846            rng,
847            chain,
848            self.stats_options(),
849        )
850    }
851
852    fn hint_num_tune(&self) -> usize {
853        usize_hint(self.num_tune, "num_tune")
854    }
855
856    fn hint_num_draws(&self) -> usize {
857        usize_hint(self.num_draws, "num_draws")
858    }
859
860    fn num_chains(&self) -> usize {
861        self.num_chains
862    }
863
864    fn seed(&self) -> u64 {
865        self.seed
866    }
867
868    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
869        StatOptions {
870            adapt: (),
871            hamiltonian: (),
872            point: {
873                let store_gradient = self.store_gradient;
874                let store_unconstrained = self.store_unconstrained;
875                let store_transformed = self.store_transformed;
876                TransformedPointStatsOptions {
877                    store_gradient,
878                    store_unconstrained,
879                    store_transformed,
880                }
881            },
882            divergence: crate::dynamics::DivergenceStatsOptions {
883                store_divergences: self.store_divergences,
884            },
885        }
886    }
887
888    fn sampler_name(&self) -> &'static str {
889        "nuts"
890    }
891
892    fn adaptation_name(&self) -> &'static str {
893        "flow"
894    }
895}
896
897impl Settings for FlowMclmcSettings {
898    type Chain<M: Math> = crate::mclmc::MclmcChain<
899        M,
900        ChaCha8Rng,
901        ExternalTransformAdaptation,
902        ExternalTransformation<M>,
903    >;
904
905    fn new_chain<M: Math, R: Rng + ?Sized>(
906        &self,
907        chain: u64,
908        mut math: M,
909        rng: &mut R,
910    ) -> Self::Chain<M> {
911        use crate::dynamics::KineticEnergyKind;
912        use crate::mclmc::MclmcChain;
913
914        let num_tune = self.num_tune;
915        let strategy =
916            ExternalTransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
917        let params = math
918            .new_transformation(rng, math.dim(), chain)
919            .expect("Failed to create external transformation");
920        let transform = ExternalTransformation::new(params);
921        let initial_kind = match self.trajectory_kind {
922            MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
923            MclmcTrajectoryKind::Euclidean
924            | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
925        };
926        let mut hamiltonian = TransformedHamiltonian::new(&mut math, transform, initial_kind);
927        hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
928        let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
929        let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
930        let stats_options = self.stats_options::<M>();
931        MclmcChain::new(
932            math,
933            hamiltonian,
934            strategy,
935            rng,
936            chain,
937            self.subsample_frequency,
938            self.dynamic_step_size,
939            self.trajectory_kind,
940            switch_draw,
941            self.max_energy_error,
942            stats_options,
943        )
944    }
945
946    fn hint_num_tune(&self) -> usize {
947        usize_hint(self.num_tune, "num_tune")
948    }
949
950    fn hint_num_draws(&self) -> usize {
951        usize_hint(self.num_draws, "num_draws")
952    }
953
954    fn num_chains(&self) -> usize {
955        self.num_chains
956    }
957
958    fn seed(&self) -> u64 {
959        self.seed
960    }
961
962    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
963        StatOptions {
964            adapt: (),
965            hamiltonian: (),
966            point: {
967                let store_gradient = self.store_gradient;
968                let store_unconstrained = self.store_unconstrained;
969                let store_transformed = self.store_transformed;
970                TransformedPointStatsOptions {
971                    store_gradient,
972                    store_unconstrained,
973                    store_transformed,
974                }
975            },
976            divergence: crate::dynamics::DivergenceStatsOptions {
977                store_divergences: self.store_divergences,
978            },
979        }
980    }
981
982    fn sampler_name(&self) -> &'static str {
983        "mclmc"
984    }
985
986    fn adaptation_name(&self) -> &'static str {
987        "flow"
988    }
989}
990
991pub fn sample_sequentially<'math, M: Math + 'math, R: Rng + ?Sized>(
992    math: M,
993    settings: DiagNutsSettings,
994    start: &[f64],
995    draws: u64,
996    chain: u64,
997    rng: &mut R,
998) -> Result<impl Iterator<Item = Result<(Box<[f64]>, Progress)>> + 'math> {
999    let mut sampler = settings.new_chain(chain, math, rng);
1000    sampler.set_position(start)?;
1001    Ok((0..draws).map(move |_| sampler.draw()))
1002}
1003
1004#[non_exhaustive]
1005#[derive(Clone, Debug)]
1006pub struct ChainProgress {
1007    pub finished_draws: usize,
1008    pub total_draws: usize,
1009    pub divergences: usize,
1010    pub tuning: bool,
1011    pub started: bool,
1012    pub latest_num_steps: usize,
1013    pub total_num_steps: usize,
1014    pub step_size: f64,
1015    pub runtime: Duration,
1016    pub divergent_draws: Vec<usize>,
1017}
1018
1019impl ChainProgress {
1020    fn new(total: usize) -> Self {
1021        Self {
1022            finished_draws: 0,
1023            total_draws: total,
1024            divergences: 0,
1025            tuning: true,
1026            started: false,
1027            latest_num_steps: 0,
1028            step_size: 0f64,
1029            total_num_steps: 0,
1030            runtime: Duration::ZERO,
1031            divergent_draws: Vec::new(),
1032        }
1033    }
1034
1035    fn update(&mut self, stats: &Progress, draw_duration: Duration) {
1036        if stats.diverging & !stats.tuning {
1037            self.divergences += 1;
1038            self.divergent_draws.push(self.finished_draws);
1039        }
1040        self.finished_draws += 1;
1041        self.tuning = stats.tuning;
1042
1043        self.latest_num_steps = stats.num_steps as usize;
1044        self.total_num_steps += stats.num_steps as usize;
1045        self.step_size = stats.step_size;
1046        self.runtime += draw_duration;
1047    }
1048}
1049
/// Control messages sent from a [`ChainProcess`] handle to its worker.
#[cfg(feature = "parallel")]
enum ChainCommand {
    /// Carry on sampling.
    Resume,
    /// Stop drawing and block until the next command arrives.
    Pause,
}
1055
/// Controller-side handle to one chain worker.
#[cfg(feature = "parallel")]
struct ChainProcess<T: TraceStorage> {
    /// Sender used to pause and resume the worker; the worker stops once it sees this
    /// channel disconnected.
    stop_marker: Sender<ChainCommand>,
    /// Per-chain trace storage shared with the worker; `None` once it has been taken
    /// for finalization.
    trace: Arc<Mutex<Option<T::ChainStorage>>>,
    /// Progress record shared with the worker.
    progress: Arc<Mutex<ChainProgress>>,
}
1065
1066#[cfg(feature = "parallel")]
1067impl<T: TraceStorage> ChainProcess<T> {
1068    fn finalize_many(trace: T, chains: Vec<Self>) -> Result<(Option<anyhow::Error>, T::Finalized)> {
1069        let finalized_chain_traces = chains
1070            .into_iter()
1071            .filter_map(|chain| chain.trace.lock().expect("Poisoned lock").take())
1072            .map(|chain| chain.finalize())
1073            .collect_vec();
1074        trace.finalize(finalized_chain_traces)
1075    }
1076
1077    fn progress(&self) -> ChainProgress {
1078        self.progress.lock().expect("Poisoned lock").clone()
1079    }
1080
1081    fn resume(&self) -> Result<()> {
1082        self.stop_marker.send(ChainCommand::Resume)?;
1083        Ok(())
1084    }
1085
1086    fn pause(&self) -> Result<()> {
1087        self.stop_marker.send(ChainCommand::Pause)?;
1088        Ok(())
1089    }
1090
1091    fn start<'model, M: Model, S: Settings>(
1092        model: &'model M,
1093        chain_trace: T::ChainStorage,
1094        chain_id: u64,
1095        seed: u64,
1096        settings: &'model S,
1097        scope: &ScopeFifo<'model>,
1098        results: Sender<Result<()>>,
1099    ) -> Result<Self> {
1100        let (stop_marker_tx, stop_marker_rx) = channel();
1101
1102        let mut rng = ChaCha8Rng::seed_from_u64(seed);
1103        rng.set_stream(chain_id + 1);
1104
1105        let chain_trace = Arc::new(Mutex::new(Some(chain_trace)));
1106        let progress = Arc::new(Mutex::new(ChainProgress::new(
1107            settings.hint_num_draws() + settings.hint_num_tune(),
1108        )));
1109
1110        let trace_inner = chain_trace.clone();
1111        let progress_inner = progress.clone();
1112
1113        scope.spawn_fifo(move |_| {
1114            let chain_trace = trace_inner;
1115            let progress = progress_inner;
1116
1117            let mut sample = move || {
1118                let logp = model
1119                    .math(&mut rng)
1120                    .context("Failed to create model density")?;
1121                let dim = logp.dim();
1122
1123                let mut sampler = settings.new_chain(chain_id, logp, &mut rng);
1124
1125                progress.lock().expect("Poisoned mutex").started = true;
1126
1127                let mut initval = vec![0f64; dim];
1128                // TODO maxtries
1129                let mut error = None;
1130                for _ in 0..500 {
1131                    model
1132                        .init_position(&mut rng, &mut initval)
1133                        .context("Failed to generate a new initial position")?;
1134                    if let Err(err) = sampler.set_position(&initval) {
1135                        error = Some(err);
1136                        continue;
1137                    }
1138                    error = None;
1139                    break;
1140                }
1141
1142                if let Some(error) = error {
1143                    return Err(error.context("All initialization points failed"));
1144                }
1145
1146                let draws = settings.hint_num_tune() + settings.hint_num_draws();
1147
1148                let mut msg = stop_marker_rx.try_recv();
1149                let mut draw = 0;
1150                loop {
1151                    match msg {
1152                        // The remote end is dead
1153                        Err(TryRecvError::Disconnected) => {
1154                            break;
1155                        }
1156                        Err(TryRecvError::Empty) => {}
1157                        Ok(ChainCommand::Pause) => {
1158                            msg = stop_marker_rx.recv().map_err(|e| e.into());
1159                            continue;
1160                        }
1161                        Ok(ChainCommand::Resume) => {}
1162                    }
1163
1164                    let now = Instant::now();
1165                    let (_point, mut draw_data, mut stats, info) = sampler.expanded_draw().unwrap();
1166
1167                    let mut guard = chain_trace
1168                        .lock()
1169                        .expect("Could not unlock trace lock. Poisoned mutex");
1170
1171                    let Some(trace_val) = guard.as_mut() else {
1172                        // The trace was removed by controller thread. We can stop sampling
1173                        break;
1174                    };
1175                    progress
1176                        .lock()
1177                        .expect("Poisoned mutex")
1178                        .update(&info, now.elapsed());
1179
1180                    let math = sampler.math();
1181                    let dims = StatsDims::from(math.deref());
1182                    trace_val.record_sample(
1183                        settings,
1184                        stats.get_all(&dims),
1185                        draw_data.get_all(math.deref()),
1186                        &info,
1187                    )?;
1188
1189                    draw += 1;
1190                    if draw == draws {
1191                        break;
1192                    }
1193
1194                    msg = stop_marker_rx.try_recv();
1195                }
1196                Ok(())
1197            };
1198
1199            let result = sample();
1200
1201            // We intentionally ignore errors here, because this means some other
1202            // chain already failed, and should have reported the error.
1203            let _ = results.send(result);
1204            drop(results);
1205        });
1206
1207        Ok(Self {
1208            trace: chain_trace,
1209            stop_marker: stop_marker_tx,
1210            progress,
1211        })
1212    }
1213
1214    fn flush(&self) -> Result<()> {
1215        self.trace
1216            .lock()
1217            .map_err(|_| anyhow::anyhow!("Could not lock trace mutex"))
1218            .context("Could not flush trace")?
1219            .as_mut()
1220            .map(|v| v.flush())
1221            .transpose()?;
1222        Ok(())
1223    }
1224}
1225
/// Requests a [`Sampler`] handle sends to its controller thread.
#[cfg(feature = "parallel")]
#[derive(Debug)]
enum SamplerCommand {
    /// Pause all chains.
    Pause,
    /// Resume all chains.
    Continue,
    /// Report the progress of every chain.
    Progress,
    /// Flush the trace storage of every chain.
    Flush,
    /// Produce an intermediate view of the trace collected so far.
    Inspect,
}
1235
/// Answers the controller thread sends back for a [`SamplerCommand`].
#[cfg(feature = "parallel")]
enum SamplerResponse<T>
where
    T: Send + 'static,
{
    /// Plain acknowledgement (pause, continue, flush).
    Ok(),
    /// Progress records of all chains.
    Progress(Box<[ChainProgress]>),
    /// Result of inspecting the trace.
    Inspect(T),
}
1242
/// Outcome of [`Sampler::wait_timeout`].
#[cfg(feature = "parallel")]
pub enum SamplerWaitResult<F>
where
    F: Send + 'static,
{
    /// Sampling finished without error; contains the finalized trace.
    Trace(F),
    /// Sampling did not finish in time; the sampler is handed back to the caller.
    Timeout(Sampler<F>),
    /// Sampling failed; the finalized trace is included when one is available.
    Err(anyhow::Error, Option<F>),
}
1249
/// Handle to a sampling run that executes on background threads.
#[cfg(feature = "parallel")]
pub struct Sampler<F>
where
    F: Send + 'static,
{
    /// Controller thread; its result is the optional error and the finalized trace.
    main_thread: JoinHandle<Result<(Option<anyhow::Error>, F)>>,
    /// Channel for sending commands to the controller thread.
    commands: SyncSender<SamplerCommand>,
    /// Channel on which the controller thread answers commands.
    responses: Receiver<SamplerResponse<(Option<anyhow::Error>, F)>>,
    /// Channel on which every chain worker reports its final result.
    results: Receiver<Result<()>>,
}
1257
/// Progress reporting hook for a [`Sampler`].
#[cfg(feature = "parallel")]
pub struct ProgressCallback {
    /// Function called with an elapsed duration and the progress of every chain.
    pub callback: Box<dyn FnMut(Duration, Box<[ChainProgress]>) + Send>,
    /// Interval after which the controller thread calls `callback` again.
    pub rate: Duration,
}
1263
1264#[cfg(feature = "parallel")]
1265impl<F: Send + 'static> Sampler<F> {
1266    pub fn new<M, S, C, T>(
1267        model: M,
1268        settings: S,
1269        trace_config: C,
1270        num_cores: usize,
1271        callback: Option<ProgressCallback>,
1272    ) -> Result<Self>
1273    where
1274        S: Settings,
1275        C: StorageConfig<Storage = T>,
1276        M: Model,
1277        T: TraceStorage<Finalized = F>,
1278    {
1279        let (commands_tx, commands_rx) = sync_channel(0);
1280        let (responses_tx, responses_rx) = sync_channel(0);
1281        let (results_tx, results_rx) = channel();
1282
1283        let main_thread = spawn(move || {
1284            let pool = ThreadPoolBuilder::new()
1285                .num_threads(num_cores + 1) // One more thread because the controller also uses one
1286                .thread_name(|i| format!("nutpie-worker-{i}"))
1287                .build()
1288                .context("Could not start thread pool")?;
1289
1290            let settings_ref = &settings;
1291            let model_ref = &model;
1292            let mut callback = callback;
1293
1294            pool.scope_fifo(move |scope| {
1295                let results = results_tx;
1296                let mut chains = Vec::with_capacity(settings.num_chains());
1297
1298                let mut rng = ChaCha8Rng::seed_from_u64(settings.seed());
1299                rng.set_stream(0);
1300
1301                let math = model_ref
1302                    .math(&mut rng)
1303                    .context("Could not create model density")?;
1304                let trace = trace_config
1305                    .new_trace(settings_ref, &math)
1306                    .context("Could not create trace object")?;
1307                drop(math);
1308
1309                for chain_id in 0..settings.num_chains() {
1310                    let chain_trace_val = trace
1311                        .initialize_trace_for_chain(chain_id as u64)
1312                        .context("Failed to create trace object")?;
1313                    let chain = ChainProcess::start(
1314                        model_ref,
1315                        chain_trace_val,
1316                        chain_id as u64,
1317                        settings.seed(),
1318                        settings_ref,
1319                        scope,
1320                        results.clone(),
1321                    );
1322                    chains.push(chain);
1323                }
1324                drop(results);
1325
1326                let (chains, errors): (Vec<_>, Vec<_>) = chains.into_iter().partition_result();
1327                if let Some(error) = errors.into_iter().next() {
1328                    let _ = ChainProcess::finalize_many(trace, chains);
1329                    return Err(error).context("Could not start chains");
1330                }
1331
1332                let mut main_loop = || {
1333                    let start_time = Instant::now();
1334                    let mut pause_start = Instant::now();
1335                    let mut pause_time = Duration::ZERO;
1336
1337                    let mut progress_rate = Duration::MAX;
1338                    if let Some(ProgressCallback { callback, rate }) = &mut callback {
1339                        let progress = chains.iter().map(|chain| chain.progress()).collect_vec();
1340                        callback(start_time.elapsed(), progress.into());
1341                        progress_rate = *rate;
1342                    }
1343                    let mut last_progress = Instant::now();
1344                    let mut is_paused = false;
1345
1346                    loop {
1347                        let timeout = progress_rate.checked_sub(last_progress.elapsed());
1348                        let timeout = timeout.unwrap_or_else(|| {
1349                            if let Some(ProgressCallback { callback, .. }) = &mut callback {
1350                                let progress =
1351                                    chains.iter().map(|chain| chain.progress()).collect_vec();
1352                                let mut elapsed = start_time.elapsed().saturating_sub(pause_time);
1353                                if is_paused {
1354                                    elapsed = elapsed.saturating_sub(pause_start.elapsed());
1355                                }
1356                                callback(elapsed, progress.into());
1357                            }
1358                            last_progress = Instant::now();
1359                            progress_rate
1360                        });
1361
1362                        // TODO return when all chains are done
1363                        match commands_rx.recv_timeout(timeout) {
1364                            Ok(SamplerCommand::Pause) => {
1365                                for chain in chains.iter() {
1366                                    // This failes if the thread is done.
1367                                    // We just want to ignore those threads.
1368                                    let _ = chain.pause();
1369                                }
1370                                if !is_paused {
1371                                    pause_start = Instant::now();
1372                                }
1373                                is_paused = true;
1374                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
1375                                    anyhow::anyhow!(
1376                                        "Could not send pause response to controller thread: {e}"
1377                                    )
1378                                })?;
1379                            }
1380                            Ok(SamplerCommand::Continue) => {
1381                                for chain in chains.iter() {
1382                                    // This failes if the thread is done.
1383                                    // We just want to ignore those threads.
1384                                    let _ = chain.resume();
1385                                }
1386                                pause_time += pause_start.elapsed();
1387                                is_paused = false;
1388                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
1389                                    anyhow::anyhow!(
1390                                        "Could not send continue response to controller thread: {e}"
1391                                    )
1392                                })?;
1393                            }
1394                            Ok(SamplerCommand::Progress) => {
1395                                let progress =
1396                                    chains.iter().map(|chain| chain.progress()).collect_vec();
1397                                responses_tx.send(SamplerResponse::Progress(progress.into())).map_err(|e| {
1398                                    anyhow::anyhow!(
1399                                        "Could not send progress response to controller thread: {e}"
1400                                    )
1401                                })?;
1402                            }
1403                            Ok(SamplerCommand::Inspect) => {
1404                                let traces = chains
1405                                    .iter()
1406                                    .filter_map(|chain| {
1407                                        chain
1408                                            .trace
1409                                            .lock()
1410                                            .expect("Poisoned lock")
1411                                            .as_ref()
1412                                            .map(|v| v.inspect())
1413                                    })
1414                                    .collect_vec();
1415                                let finalized_trace = trace.inspect(traces)?;
1416                                responses_tx.send(SamplerResponse::Inspect(finalized_trace)).map_err(|e| {
1417                                    anyhow::anyhow!(
1418                                        "Could not send inspect response to controller thread: {e}"
1419                                    )
1420                                })?;
1421                            }
1422                            Ok(SamplerCommand::Flush) => {
1423                                for chain in chains.iter() {
1424                                    chain.flush()?;
1425                                }
1426                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
1427                                    anyhow::anyhow!(
1428                                        "Could not send flush response to controller thread: {e}"
1429                                    )
1430                                })?;
1431                            }
1432                            Err(RecvTimeoutError::Timeout) => {}
1433                            Err(RecvTimeoutError::Disconnected) => {
1434                                if let Some(ProgressCallback { callback, .. }) = &mut callback {
1435                                    let progress =
1436                                        chains.iter().map(|chain| chain.progress()).collect_vec();
1437                                    let mut elapsed =
1438                                        start_time.elapsed().saturating_sub(pause_time);
1439                                    if is_paused {
1440                                        elapsed = elapsed.saturating_sub(pause_start.elapsed());
1441                                    }
1442                                    callback(elapsed, progress.into());
1443                                }
1444                                return Ok(());
1445                            }
1446                        };
1447                    }
1448                };
1449                let result: Result<()> = main_loop();
1450                // Run finalization even if something failed
1451                let output = ChainProcess::finalize_many(trace, chains)?;
1452
1453                result?;
1454                Ok(output)
1455            })
1456        });
1457
1458        Ok(Self {
1459            main_thread,
1460            commands: commands_tx,
1461            responses: responses_rx,
1462            results: results_rx,
1463        })
1464    }
1465
1466    pub fn pause(&mut self) -> Result<()> {
1467        self.commands
1468            .send(SamplerCommand::Pause)
1469            .context("Could not send pause command to controller thread")?;
1470        let response = self
1471            .responses
1472            .recv()
1473            .context("Could not recieve pause response from controller thread")?;
1474        let SamplerResponse::Ok() = response else {
1475            bail!("Got invalid response from sample controller thread");
1476        };
1477        Ok(())
1478    }
1479
1480    pub fn resume(&mut self) -> Result<()> {
1481        self.commands.send(SamplerCommand::Continue)?;
1482        let response = self.responses.recv()?;
1483        let SamplerResponse::Ok() = response else {
1484            bail!("Got invalid response from sample controller thread");
1485        };
1486        Ok(())
1487    }
1488
1489    pub fn flush(&mut self) -> Result<()> {
1490        self.commands.send(SamplerCommand::Flush)?;
1491        let response = self
1492            .responses
1493            .recv()
1494            .context("Could not recieve flush response from controller thread")?;
1495        let SamplerResponse::Ok() = response else {
1496            bail!("Got invalid response from sample controller thread");
1497        };
1498        Ok(())
1499    }
1500
1501    pub fn inspect(&mut self) -> Result<(Option<anyhow::Error>, F)> {
1502        self.commands.send(SamplerCommand::Inspect)?;
1503        let response = self
1504            .responses
1505            .recv()
1506            .context("Could not recieve inspect response from controller thread")?;
1507        let SamplerResponse::Inspect(trace) = response else {
1508            bail!("Got invalid response from sample controller thread");
1509        };
1510        Ok(trace)
1511    }
1512
1513    pub fn abort(self) -> Result<(Option<anyhow::Error>, F)> {
1514        drop(self.commands);
1515        let result = self.main_thread.join();
1516        match result {
1517            Err(payload) => std::panic::resume_unwind(payload),
1518            Ok(Ok(val)) => Ok(val),
1519            Ok(Err(err)) => Err(err),
1520        }
1521    }
1522
1523    pub fn wait_timeout(self, timeout: Duration) -> SamplerWaitResult<F> {
1524        let start = Instant::now();
1525        let mut remaining = Some(timeout);
1526        while remaining.is_some() {
1527            match self.results.recv_timeout(timeout) {
1528                Ok(Ok(_)) => remaining = timeout.checked_sub(start.elapsed()),
1529                Ok(Err(e)) => return SamplerWaitResult::Err(e, None),
1530                Err(RecvTimeoutError::Disconnected) => match self.abort() {
1531                    Ok((Some(err), trace)) => return SamplerWaitResult::Err(err, Some(trace)),
1532                    Ok((None, trace)) => return SamplerWaitResult::Trace(trace),
1533                    Err(err) => return SamplerWaitResult::Err(err, None),
1534                },
1535                Err(RecvTimeoutError::Timeout) => break,
1536            }
1537        }
1538        SamplerWaitResult::Timeout(self)
1539    }
1540
1541    pub fn progress(&mut self) -> Result<Box<[ChainProgress]>> {
1542        self.commands.send(SamplerCommand::Progress)?;
1543        let response = self.responses.recv()?;
1544        let SamplerResponse::Progress(progress) = response else {
1545            bail!("Got invalid response from sample controller thread");
1546        };
1547        Ok(progress)
1548    }
1549}
1550
/// Test helpers: a minimal [`Model`] built around a CPU log-density function.
///
/// Everything in here is only compiled when the `zarr` feature is enabled.
#[cfg(test)]
pub mod test_logps {
    #[cfg(feature = "zarr")]
    use crate::{
        Model,
        math::{CpuLogpFunc, CpuMath},
    };
    #[cfg(feature = "zarr")]
    use anyhow::Result;
    #[cfg(feature = "zarr")]
    use rand::Rng;

    /// Model that owns a log-density function and hands out `CpuMath` views of it.
    #[cfg(feature = "zarr")]
    pub struct CpuModel<F> {
        logp: F,
    }

    #[cfg(feature = "zarr")]
    impl<F> CpuModel<F> {
        /// Wrap `logp` in a model.
        pub fn new(logp: F) -> Self {
            Self { logp }
        }
    }

    #[cfg(feature = "zarr")]
    impl<F> Model for CpuModel<F>
    where
        F: Send + Sync + 'static,
        for<'a> &'a F: CpuLogpFunc,
    {
        type Math<'model> = CpuMath<&'model F>;

        /// Create a `CpuMath` that borrows the wrapped log-density; `_rng` is unused.
        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            Ok(CpuMath::new(&self.logp))
        }

        /// Always start at the origin; `_rng` is unused.
        fn init_position<R: Rng + ?Sized>(
            &self,
            _rng: &mut R,
            position: &mut [f64],
        ) -> Result<()> {
            position.fill(0.);
            Ok(())
        }
    }
}
1594
1595#[cfg(test)]
1596mod tests {
1597    use crate::math::test_logps::NormalLogp;
1598    use crate::{
1599        Chain, math::CpuMath, sample_sequentially, sampler::DiagMclmcSettings,
1600        sampler::DiagNutsSettings, sampler::LowRankMclmcSettings, sampler::LowRankNutsSettings,
1601        sampler::Settings,
1602    };
1603
1604    #[cfg(feature = "zarr")]
1605    use super::test_logps::CpuModel;
1606
1607    use anyhow::Result;
1608    use itertools::Itertools;
1609    use pretty_assertions::assert_eq;
1610    use rand::{SeedableRng, rngs::StdRng};
1611
1612    #[cfg(feature = "zarr")]
1613    use std::{
1614        sync::Arc,
1615        time::{Duration, Instant},
1616    };
1617
1618    #[cfg(feature = "zarr")]
1619    use crate::{Sampler, ZarrConfig};
1620
1621    #[cfg(feature = "zarr")]
1622    use zarrs::storage::store::MemoryStore;
1623
1624    fn assert_settings_smoke<S: Settings>(settings: S) -> Result<()> {
1625        let logp = NormalLogp { dim: 4, mu: 0.1 };
1626        let math = CpuMath::new(&logp);
1627        let mut rng = StdRng::seed_from_u64(42);
1628
1629        let stat_names = settings.stat_names(&math);
1630        let stat_types = settings.stat_types(&math);
1631        assert!(!stat_names.is_empty());
1632        assert_eq!(stat_names.len(), stat_types.len());
1633
1634        let mut chain = settings.new_chain(0, math, &mut rng);
1635        chain.set_position(&vec![0.2; 4])?;
1636        let (_draw, _info) = chain.draw()?;
1637        Ok(())
1638    }
1639
1640    #[test]
1641    fn all_settings_smoke() -> Result<()> {
1642        assert_settings_smoke(DiagNutsSettings {
1643            num_tune: 10,
1644            num_draws: 10,
1645            ..Default::default()
1646        })?;
1647        assert_settings_smoke(LowRankNutsSettings {
1648            num_tune: 10,
1649            num_draws: 10,
1650            ..Default::default()
1651        })?;
1652        assert_settings_smoke(DiagMclmcSettings {
1653            num_tune: 10,
1654            num_draws: 10,
1655            ..Default::default()
1656        })?;
1657        assert_settings_smoke(LowRankMclmcSettings {
1658            num_tune: 10,
1659            num_draws: 10,
1660            ..Default::default()
1661        })?;
1662        Ok(())
1663    }
1664
1665    #[test]
1666    fn sample_chain() -> Result<()> {
1667        let logp = NormalLogp { dim: 10, mu: 0.1 };
1668        let math = CpuMath::new(&logp);
1669        let settings = DiagNutsSettings {
1670            num_tune: 100,
1671            num_draws: 100,
1672            ..Default::default()
1673        };
1674        let start = vec![0.2; 10];
1675
1676        let mut rng = StdRng::seed_from_u64(42);
1677
1678        let mut chain = settings.new_chain(0, math, &mut rng);
1679
1680        let (_draw, info) = chain.draw()?;
1681        assert!(info.tuning);
1682        assert_eq!(info.draw, 0);
1683
1684        let math = CpuMath::new(&logp);
1685        let chain = sample_sequentially(math, settings, &start, 200, 1, &mut rng).unwrap();
1686        let mut draws = chain.collect_vec();
1687        assert_eq!(draws.len(), 200);
1688
1689        let draw0 = draws.remove(100).unwrap();
1690        let (vals, stats) = draw0;
1691        assert_eq!(vals.len(), 10);
1692        assert_eq!(stats.chain, 1);
1693        assert_eq!(stats.draw, 100);
1694        Ok(())
1695    }
1696
1697    #[cfg(feature = "zarr")]
1698    #[test]
1699    fn sample_parallel() -> Result<()> {
1700        let logp = NormalLogp { dim: 100, mu: 0.1 };
1701        let settings = DiagNutsSettings {
1702            num_tune: 100,
1703            num_draws: 100,
1704            seed: 10,
1705            ..Default::default()
1706        };
1707
1708        let model = CpuModel::new(logp.clone());
1709        let store = MemoryStore::new();
1710
1711        let zarr_config = ZarrConfig::new(Arc::new(store));
1712        let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
1713        sampler.pause()?;
1714        sampler.pause()?;
1715        // TODO flush trace
1716        sampler.resume()?;
1717        let (ok, _) = sampler.abort()?;
1718        if let Some(err) = ok {
1719            Err(err)?;
1720        }
1721
1722        let store = MemoryStore::new();
1723        let zarr_config = ZarrConfig::new(Arc::new(store));
1724        let model = CpuModel::new(logp.clone());
1725        let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
1726        sampler.pause()?;
1727        if let (Some(err), _) = sampler.abort()? {
1728            Err(err)?;
1729        }
1730
1731        let store = MemoryStore::new();
1732        let zarr_config = ZarrConfig::new(Arc::new(store));
1733        let model = CpuModel::new(logp.clone());
1734        let start = Instant::now();
1735        let sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
1736
1737        let mut sampler = match sampler.wait_timeout(Duration::from_nanos(100)) {
1738            super::SamplerWaitResult::Trace(_) => {
1739                dbg!(start.elapsed());
1740                panic!("finished");
1741            }
1742            super::SamplerWaitResult::Timeout(sampler) => sampler,
1743            super::SamplerWaitResult::Err(_, _) => {
1744                panic!("error")
1745            }
1746        };
1747
1748        for _ in 0..30 {
1749            sampler.progress()?;
1750        }
1751
1752        match sampler.wait_timeout(Duration::from_secs(1)) {
1753            super::SamplerWaitResult::Trace(_) => {
1754                dbg!(start.elapsed());
1755            }
1756            super::SamplerWaitResult::Timeout(_) => {
1757                panic!("timeout")
1758            }
1759            super::SamplerWaitResult::Err(err, _) => Err(err)?,
1760        };
1761
1762        Ok(())
1763    }
1764
1765    #[test]
1766    fn sample_seq() {
1767        let logp = NormalLogp { dim: 10, mu: 0.1 };
1768        let math = CpuMath::new(&logp);
1769        let settings = DiagNutsSettings {
1770            num_tune: 100,
1771            num_draws: 100,
1772            ..Default::default()
1773        };
1774        let start = vec![0.2; 10];
1775
1776        let mut rng = StdRng::seed_from_u64(42);
1777
1778        let chain = sample_sequentially(math, settings, &start, 200, 1, &mut rng).unwrap();
1779        let mut draws = chain.collect_vec();
1780        assert_eq!(draws.len(), 200);
1781
1782        let draw0 = draws.remove(100).unwrap();
1783        let (vals, stats) = draw0;
1784        assert_eq!(vals.len(), 10);
1785        assert_eq!(stats.chain, 1);
1786        assert_eq!(stats.draw, 100);
1787    }
1788}