Skip to main content

nuts_rs/
sampler.rs

1//! High-level sampler entry points: `Settings` presets, the parallel `Sampler`,
2//! and `sample_sequentially` for running one or many chains.
3
4use anyhow::{Context, Result, bail};
5use itertools::Itertools;
6use nuts_storable::{HasDims, Storable, Value};
7use rand::{Rng, SeedableRng, rngs::ChaCha8Rng};
8use rayon::{ScopeFifo, ThreadPoolBuilder};
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use std::{
11    collections::HashMap,
12    fmt::Debug,
13    ops::Deref,
14    sync::{
15        Arc, Mutex,
16        mpsc::{
17            Receiver, RecvTimeoutError, Sender, SyncSender, TryRecvError, channel, sync_channel,
18        },
19    },
20    thread::{JoinHandle, spawn},
21    time::{Duration, Instant},
22};
23
24use crate::{
25    DiagAdaptExpSettings, Math, StepSizeAdaptMethod,
26    adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy, GlobalStrategyStatsOptions},
27    chain::{AdaptStrategy, Chain, NutsChain, StatOptions},
28    dynamics::{KineticEnergyKind, TransformedHamiltonian, TransformedPointStatsOptions},
29    external_adapt_strategy::{ExternalTransformAdaptation, FlowSettings},
30    mclmc::MclmcTrajectoryKind,
31    model::Model,
32    nuts::NutsOptions,
33    sampler_stats::{SamplerStats, StatsDims},
34    storage::{ChainStorage, StorageConfig, TraceStorage},
35    transform::{
36        DiagAdaptStrategy, DiagMassMatrix, ExternalTransformation, LowRankMassMatrix,
37        LowRankMassMatrixStrategy, LowRankSettings,
38    },
39};
40
/// All sampler configurations implement this trait.
///
/// A `Settings` value fully describes one sampler variant: how to construct a
/// chain, how many tuning and sampling draws to run, and how the sampler's
/// statistics are named, typed, and dimensioned for storage. The trait is
/// sealed (see [`private::Sealed`]), so only the settings types in this
/// module can implement it.
pub trait Settings:
    private::Sealed + Clone + Copy + Default + Sync + Send + Serialize + DeserializeOwned + 'static
{
    /// The concrete chain type this configuration produces.
    type Chain<M: Math>: Chain<M>;

    /// Create a new chain with the given chain index, math backend, and RNG.
    ///
    /// The RNG is only used to derive the chain's own seeded RNG, so chains
    /// remain reproducible and independent.
    fn new_chain<M: Math, R: Rng + ?Sized>(
        &self,
        chain: u64,
        math: M,
        rng: &mut R,
    ) -> Self::Chain<M>;

    /// Number of tuning (warmup) draws, as a `usize` hint for preallocation.
    fn hint_num_tune(&self) -> usize;
    /// Number of post-warmup draws, as a `usize` hint for preallocation.
    fn hint_num_draws(&self) -> usize;
    /// Number of chains to run.
    fn num_chains(&self) -> usize;
    /// Seed used to derive the per-chain RNGs.
    fn seed(&self) -> u64;
    /// Options controlling which sampler statistics the chain records.
    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions;
    /// Short identifier for the sampler family (e.g. "nuts" or "mclmc").
    fn sampler_name(&self) -> &'static str;
    /// Short identifier for the adaptation strategy (e.g. "diagonal").
    fn adaptation_name(&self) -> &'static str;

    /// Names of all per-draw sampler statistics for this chain type.
    fn stat_names<M: Math>(&self, math: &M) -> Vec<String> {
        let dims = StatsDims::from(math);
        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::names(&dims)
            .into_iter()
            .map(String::from)
            .collect()
    }

    /// Names of the values stored for each draw (from `M::ExpandedVector`).
    fn data_names<M: Math>(&self, math: &M) -> Vec<String> {
        <M::ExpandedVector as Storable<_>>::names(math)
            .into_iter()
            .map(String::from)
            .collect()
    }

    /// `(name, item type)` pairs for every sampler statistic.
    fn stat_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
        self.stat_names(math)
            .into_iter()
            .map(|name| (name.clone(), self.stat_type::<M>(math, &name)))
            .collect()
    }

    /// Item type of a single sampler statistic, looked up by name.
    fn stat_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
        let dims = StatsDims::from(math);
        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::item_type(&dims, name)
    }

    /// `(name, item type)` pairs for every per-draw data value.
    fn data_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
        self.data_names(math)
            .into_iter()
            .map(|name| (name.clone(), self.data_type(math, &name)))
            .collect()
    }
    /// Item type of a single per-draw data value, looked up by name.
    fn data_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
        <M::ExpandedVector as Storable<_>>::item_type(math, name)
    }

    /// `(name, dimension names)` pairs for every sampler statistic.
    fn stat_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
        self.stat_names(math)
            .into_iter()
            .map(|name| (name.clone(), self.stat_dims::<M>(math, &name)))
            .collect()
    }

    /// Dimension names of a single sampler statistic.
    fn stat_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
        let dims = StatsDims::from(math);
        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::dims(&dims, name)
            .into_iter()
            .map(String::from)
            .collect()
    }

    /// Sizes of all statistic dimensions, keyed by dimension name.
    fn stat_dim_sizes<M: Math>(&self, math: &M) -> HashMap<String, u64> {
        let dims = StatsDims::from(math);
        dims.dim_sizes()
    }

    /// `(name, dimension names)` pairs for every per-draw data value.
    fn data_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
        self.data_names(math)
            .into_iter()
            .map(|name| (name.clone(), self.data_dims(math, &name)))
            .collect()
    }

    /// Dimension names of a single per-draw data value.
    fn data_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
        <M::ExpandedVector as Storable<_>>::dims(math, name)
            .into_iter()
            .map(String::from)
            .collect()
    }

    /// Coordinate values for the statistic dimensions, keyed by name.
    fn stat_coords<M: Math>(&self, math: &M) -> HashMap<String, Value> {
        let dims = StatsDims::from(math);
        dims.coords()
    }

    /// `(name, optional event dimension)` pairs for every sampler statistic.
    fn stat_event_dims<M: Math>(&self, math: &M) -> Vec<(String, Option<String>)> {
        let dims = StatsDims::from(math);
        self.stat_names(math)
            .into_iter()
            .map(|name| {
                let event_dim =
                    <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::event_dim(
                        &dims, &name,
                    )
                    .map(String::from);
                (name, event_dim)
            })
            .collect()
    }
}
153
/// A progress snapshot reported by a running chain.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Progress {
    // Index of the most recent draw (tuning draws included).
    pub draw: u64,
    // Index of the chain reporting this progress.
    pub chain: u64,
    // Whether the most recent draw diverged.
    pub diverging: bool,
    // Whether the chain is still in the tuning (warmup) phase.
    pub tuning: bool,
    // Step size currently used by the integrator.
    pub step_size: f64,
    // Number of integration steps in the most recent draw.
    pub num_steps: u64,
}
164
/// Seals the [`Settings`] trait: `Settings` requires this private supertrait,
/// and only the types listed below can implement it, so downstream crates
/// cannot add their own `Settings` implementations.
mod private {
    use super::{
        DiagMclmcSettings, DiagNutsSettings, FlowMclmcSettings, FlowNutsSettings,
        LowRankMclmcSettings, LowRankNutsSettings,
    };

    pub trait Sealed {}

    impl Sealed for DiagNutsSettings {}

    impl Sealed for LowRankNutsSettings {}

    impl Sealed for FlowNutsSettings {}

    impl Sealed for DiagMclmcSettings {}

    impl Sealed for LowRankMclmcSettings {}

    impl Sealed for FlowMclmcSettings {}
}
185
/// Settings for the NUTS sampler
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct NutsSettings<A: Debug + Copy + Default + Serialize> {
    /// The number of tuning steps, where we fit the step size and geometry.
    pub num_tune: u64,
    /// The number of draws after tuning
    pub num_draws: u64,
    /// The maximum tree depth during sampling. The number of leapfrog steps
    /// is smaller than 2 ^ maxdepth.
    pub maxdepth: u64,
    /// The minimum tree depth during sampling. The number of leapfrog steps
    /// is larger than 2 ^ mindepth.
    pub mindepth: u64,
    /// Store the gradient in the SampleStats
    pub store_gradient: bool,
    /// Store each unconstrained parameter vector in the sampler stats
    pub store_unconstrained: bool,
    /// Store the transformed gradient and value in the sampler stats
    pub store_transformed: bool,
    /// If the energy error is larger than this threshold we treat the leapfrog
    /// step as a divergence.
    pub max_energy_error: f64,
    /// Store detailed information about each divergence in the sampler stats
    pub store_divergences: bool,
    /// Settings for geometry adaptation.
    pub adapt_options: A,
    /// Whether to check for U-turns while building trajectories
    /// (forwarded to [`NutsOptions`]).
    pub check_turning: bool,
    /// Optional target integration time (forwarded to [`NutsOptions`]).
    /// `None` disables it.
    pub target_integration_time: Option<f64>,
    /// Selects the kinetic-energy form and the corresponding integrator.
    ///
    /// - [`KineticEnergyKind::Euclidean`]: standard leapfrog (default for most settings).
    /// - [`KineticEnergyKind::ExactNormal`]: geodesic leapfrog exact for a standard-normal
    ///   potential.
    /// - [`KineticEnergyKind::Microcanonical`]: isokinetic ESH-dynamics leapfrog (microcanonical
    ///   HMC); momentum is constrained to the unit sphere.
    pub trajectory_kind: KineticEnergyKind,
    /// The number of chains to run.
    pub num_chains: usize,
    /// Seed used to derive the per-chain random number generators.
    pub seed: u64,
    /// Number of extra doublings to perform after reaching maxdepth. This can
    /// be used to increase the effective sample size at the cost of more
    /// expensive sampling.
    pub extra_doublings: u64,
}
229
/// NUTS with diagonal mass-matrix adaptation (the standard configuration).
pub type DiagNutsSettings = NutsSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;
/// Backwards-compatible alias for [`DiagNutsSettings`].
#[deprecated(since = "0.0.0", note = "Use DiagNutsSettings instead")]
pub type DiagGradNutsSettings = DiagNutsSettings;
/// NUTS with low-rank mass-matrix adaptation.
pub type LowRankNutsSettings = NutsSettings<EuclideanAdaptOptions<LowRankSettings>>;
/// NUTS with a learned flow transformation.
pub type FlowNutsSettings = NutsSettings<FlowSettings>;
/// Backwards-compatible alias for [`FlowNutsSettings`].
#[deprecated(since = "0.0.0", note = "Use FlowNutsSettings instead")]
pub type TransformedNutsSettings = FlowNutsSettings;
239
/// Settings for the unadjusted Microcanonical Langevin Monte Carlo (MCLMC) sampler.
///
/// > ⚠️ **Experimental — use with caution**: The MCLMC sampler and all of its
/// > variants are highly experimental. They have not been thoroughly validated
/// > and may **not return correct posteriors**. The API, defaults, and
/// > adaptation behaviour are all subject to breaking changes at any time.
/// > Do not use these samplers in production or for results you rely on.
///
/// Step size `ε` and momentum decoherence length `L` are **constants** — no
/// adaptation of those is performed yet. The geometry is adapted during
/// warmup using the sampler-specific adaptation strategy, while the step size
/// remains fixed.
///
/// Use the type aliases [`DiagMclmcSettings`], [`LowRankMclmcSettings`], and
/// [`FlowMclmcSettings`] for concrete configurations.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct MclmcSettings<A: Debug + Copy + Default + Serialize> {
    /// Step size ε for the ESH leapfrog integrator. Constant for the whole
    /// run; no step-size adaptation is performed.
    pub step_size: f64,
    /// Momentum decoherence length L (controls partial momentum refresh rate).
    /// Set to `f64::INFINITY` to disable momentum refresh entirely.
    pub momentum_decoherence_length: f64,
    /// Number of warmup draws.
    pub num_tune: u64,
    /// Number of sampling draws after warmup.
    pub num_draws: u64,
    /// Number of parallel chains.
    pub num_chains: usize,
    /// RNG seed.
    pub seed: u64,
    /// Maximum energy error before a step is flagged as a divergence.
    pub max_energy_error: f64,
    /// Store each unconstrained parameter vector in the sampler stats.
    pub store_unconstrained: bool,
    /// Store the gradient in the sampler stats.
    pub store_gradient: bool,
    /// Store the transformed gradient and value in the sampler stats
    pub store_transformed: bool,
    /// Store detailed information about each divergence in the sampler stats
    pub store_divergences: bool,
    /// Geometry adaptation options (step-size fields are ignored for Euclidean settings).
    pub adapt_options: A,
    /// Number of leapfrog steps per draw as a fraction of `L / ε`.
    ///
    /// The number of leapfrog steps between collector calls is:
    /// `round(subsample_frequency * L / ε).max(1)`
    ///
    /// - `1.0` (default) — one sample per full trajectory (at the final step).
    /// - `0.0` — every leapfrog step.
    /// - Values in between space samples as a fraction of the decoherence
    ///   length, so the interval scales naturally when `L` or `ε` changes.
    pub subsample_frequency: f64,
    /// When `true`, use the tree-structured step size retry on divergence:
    /// halve the step size factor and try 2 steps before doubling back.
    /// `log_weight` will include `log(step_size)` to correct for the varying
    /// sampling density. When `false`, divergences are recorded immediately
    /// without any retry and `log_weight = -energy_change`.
    pub dynamic_step_size: bool,
    /// Selects which leapfrog integrator and partial-momentum-refresh style
    /// to use.  See [`MclmcTrajectoryKind`] for the available options.
    /// Default: [`MclmcTrajectoryKind::Microcanonical`] (original MCLMC).
    pub trajectory_kind: MclmcTrajectoryKind,
    /// Fraction of `num_tune` draws at which the trajectory is switched from
    /// Euclidean to Microcanonical when
    /// `trajectory_kind == MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical`.
    /// Ignored for other trajectory kinds.  Default: `0.3`.
    pub trajectory_switch_fraction: f64,
}
308
/// MCLMC settings with a diagonal mass matrix adaptation.
///
/// > ⚠️ **Experimental — use with caution**: Highly experimental. Correctness
/// > of the returned posteriors has not been verified. May change at any time.
pub type DiagMclmcSettings = MclmcSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;
/// MCLMC settings with a low-rank mass matrix adaptation.
///
/// > ⚠️ **Experimental — use with caution**: Highly experimental. Correctness
/// > of the returned posteriors has not been verified. May change at any time.
pub type LowRankMclmcSettings = MclmcSettings<EuclideanAdaptOptions<LowRankSettings>>;
/// MCLMC settings with a learned flow transformation.
///
/// > ⚠️ **Experimental — use with caution**: Highly experimental. Correctness
/// > of the returned posteriors has not been verified. May change at any time.
pub type FlowMclmcSettings = MclmcSettings<FlowSettings>;
/// Backwards-compatible alias for [`FlowMclmcSettings`].
/// Carries the same experimental caveats as [`FlowMclmcSettings`].
#[deprecated(since = "0.0.0", note = "Use FlowMclmcSettings instead")]
pub type TransformedMclmcSettings = FlowMclmcSettings;
327
/// Convert a `u64` draw count to `usize`, panicking with the field name
/// when the value does not fit on the current platform.
fn usize_hint(value: u64, field: &str) -> usize {
    match usize::try_from(value) {
        Ok(converted) => converted,
        Err(_) => panic!("{field} must be smaller than usize::MAX"),
    }
}
333
334fn default_mclmc_settings<A: Debug + Copy + Default + Serialize>(
335    adapt_options: A,
336    num_tune: u64,
337    num_chains: usize,
338    max_energy_error: f64,
339) -> MclmcSettings<A> {
340    MclmcSettings {
341        step_size: 0.5,
342        momentum_decoherence_length: 3.0,
343        num_tune,
344        num_draws: 1000,
345        num_chains,
346        seed: 0,
347        max_energy_error,
348        store_unconstrained: false,
349        store_gradient: false,
350        store_divergences: false,
351        store_transformed: false,
352        adapt_options,
353        subsample_frequency: 1.0,
354        dynamic_step_size: true,
355        trajectory_kind: MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical,
356        trajectory_switch_fraction: 0.3,
357    }
358}
359
360impl Default for DiagMclmcSettings {
361    fn default() -> Self {
362        let mut adapt_options = EuclideanAdaptOptions::default();
363        adapt_options.step_size_settings.adapt_options.method = StepSizeAdaptMethod::Fixed(0.5);
364        default_mclmc_settings(adapt_options, 400, 6, 1000.0)
365    }
366}
367
368impl Default for LowRankMclmcSettings {
369    fn default() -> Self {
370        let mut adapt_options = EuclideanAdaptOptions::default();
371        adapt_options.early_mass_matrix_switch_freq = 20;
372        adapt_options.step_size_settings.adapt_options.method = StepSizeAdaptMethod::Fixed(0.5);
373        default_mclmc_settings(adapt_options, 800, 6, 1000.0)
374    }
375}
376
377impl Default for FlowMclmcSettings {
378    fn default() -> Self {
379        default_mclmc_settings(FlowSettings::default(), 1500, 1, 20.0)
380    }
381}
382
/// MCLMC chain using diagonal mass-matrix adaptation.
type DiagMclmcChain<M> = crate::mclmc::MclmcChain<
    M,
    ChaCha8Rng,
    GlobalStrategy<M, DiagAdaptStrategy<M>>,
    DiagMassMatrix<M>,
>;
/// MCLMC chain using low-rank mass-matrix adaptation.
type LowRankMclmcChain<M> = crate::mclmc::MclmcChain<
    M,
    ChaCha8Rng,
    GlobalStrategy<M, LowRankMassMatrixStrategy>,
    LowRankMassMatrix<M>,
>;
395
396impl Settings for DiagMclmcSettings {
397    type Chain<M: Math> = DiagMclmcChain<M>;
398
399    fn new_chain<M: Math, R: Rng + ?Sized>(
400        &self,
401        chain: u64,
402        mut math: M,
403        rng: &mut R,
404    ) -> Self::Chain<M> {
405        use crate::dynamics::KineticEnergyKind;
406        use crate::mclmc::MclmcChain;
407        use crate::stepsize::StepSizeAdaptMethod;
408
409        let num_tune = self.num_tune;
410        let mut adapt_options = self.adapt_options;
411        adapt_options.step_size_settings.adapt_options.method =
412            StepSizeAdaptMethod::Fixed(self.step_size);
413        let strategy = GlobalStrategy::<M, DiagAdaptStrategy<M>>::new(
414            &mut math,
415            adapt_options,
416            num_tune,
417            chain,
418        );
419        let mass_matrix = DiagMassMatrix::new(
420            &mut math,
421            self.adapt_options.mass_matrix_options.store_mass_matrix,
422        );
423        let initial_kind = match self.trajectory_kind {
424            MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
425            MclmcTrajectoryKind::Euclidean
426            | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
427        };
428        let mut hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, initial_kind);
429        hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
430        let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
431        let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
432        let stats_options = self.stats_options::<M>();
433        MclmcChain::new(
434            math,
435            hamiltonian,
436            strategy,
437            rng,
438            chain,
439            self.subsample_frequency,
440            self.dynamic_step_size,
441            self.trajectory_kind,
442            switch_draw,
443            self.max_energy_error,
444            stats_options,
445        )
446    }
447
448    fn hint_num_tune(&self) -> usize {
449        usize_hint(self.num_tune, "num_tune")
450    }
451
452    fn hint_num_draws(&self) -> usize {
453        usize_hint(self.num_draws, "num_draws")
454    }
455
456    fn num_chains(&self) -> usize {
457        self.num_chains
458    }
459
460    fn seed(&self) -> u64 {
461        self.seed
462    }
463
464    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
465        StatOptions {
466            adapt: GlobalStrategyStatsOptions {
467                step_size: (),
468                mass_matrix: (),
469            },
470            hamiltonian: -1,
471            point: {
472                let store_gradient = self.store_gradient;
473                let store_unconstrained = self.store_unconstrained;
474                let store_transformed = self.store_transformed;
475                TransformedPointStatsOptions {
476                    store_gradient,
477                    store_unconstrained,
478                    store_transformed,
479                }
480            },
481            divergence: crate::dynamics::DivergenceStatsOptions {
482                store_divergences: self.store_divergences,
483            },
484        }
485    }
486
487    fn sampler_name(&self) -> &'static str {
488        "mclmc"
489    }
490
491    fn adaptation_name(&self) -> &'static str {
492        "diagonal"
493    }
494}
495
496fn default_nuts_settings<A: Debug + Copy + Default + Serialize>(
497    adapt_options: A,
498    num_tune: u64,
499    num_chains: usize,
500    max_energy_error: f64,
501) -> NutsSettings<A> {
502    NutsSettings {
503        num_tune,
504        num_draws: 1000,
505        maxdepth: 10,
506        mindepth: 0,
507        max_energy_error,
508        store_gradient: false,
509        store_unconstrained: false,
510        store_transformed: false,
511        store_divergences: false,
512        adapt_options,
513        check_turning: true,
514        seed: 0,
515        num_chains,
516        target_integration_time: None,
517        trajectory_kind: KineticEnergyKind::Euclidean,
518        extra_doublings: 0,
519    }
520}
521
522impl Settings for LowRankMclmcSettings {
523    type Chain<M: Math> = LowRankMclmcChain<M>;
524
525    fn new_chain<M: Math, R: Rng + ?Sized>(
526        &self,
527        chain: u64,
528        mut math: M,
529        rng: &mut R,
530    ) -> Self::Chain<M> {
531        use crate::dynamics::KineticEnergyKind;
532        use crate::mclmc::MclmcChain;
533        use crate::stepsize::StepSizeAdaptMethod;
534
535        let num_tune = self.num_tune;
536        let mut adapt_options = self.adapt_options;
537        adapt_options.step_size_settings.adapt_options.method =
538            StepSizeAdaptMethod::Fixed(self.step_size);
539        let strategy = GlobalStrategy::<M, LowRankMassMatrixStrategy>::new(
540            &mut math,
541            adapt_options,
542            num_tune,
543            chain,
544        );
545        let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
546        let initial_kind = match self.trajectory_kind {
547            MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
548            MclmcTrajectoryKind::Euclidean
549            | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
550        };
551        let mut hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, initial_kind);
552        hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
553        let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
554        let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
555        let stats_options = self.stats_options::<M>();
556        MclmcChain::new(
557            math,
558            hamiltonian,
559            strategy,
560            rng,
561            chain,
562            self.subsample_frequency,
563            self.dynamic_step_size,
564            self.trajectory_kind,
565            switch_draw,
566            self.max_energy_error,
567            stats_options,
568        )
569    }
570
571    fn hint_num_tune(&self) -> usize {
572        usize_hint(self.num_tune, "num_tune")
573    }
574
575    fn hint_num_draws(&self) -> usize {
576        usize_hint(self.num_draws, "num_draws")
577    }
578
579    fn num_chains(&self) -> usize {
580        self.num_chains
581    }
582
583    fn seed(&self) -> u64 {
584        self.seed
585    }
586
587    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
588        StatOptions {
589            adapt: GlobalStrategyStatsOptions {
590                step_size: (),
591                mass_matrix: (),
592            },
593            hamiltonian: -1,
594            point: {
595                let store_gradient = self.store_gradient;
596                let store_unconstrained = self.store_unconstrained;
597                let store_transformed = self.store_transformed;
598                TransformedPointStatsOptions {
599                    store_gradient,
600                    store_unconstrained,
601                    store_transformed,
602                }
603            },
604            divergence: crate::dynamics::DivergenceStatsOptions {
605                store_divergences: self.store_divergences,
606            },
607        }
608    }
609
610    fn sampler_name(&self) -> &'static str {
611        "mclmc"
612    }
613
614    fn adaptation_name(&self) -> &'static str {
615        "low_rank"
616    }
617}
618
619impl Default for DiagNutsSettings {
620    fn default() -> Self {
621        default_nuts_settings(EuclideanAdaptOptions::default(), 400, 6, 1000.0)
622    }
623}
624
625impl Default for LowRankNutsSettings {
626    fn default() -> Self {
627        let mut vals = default_nuts_settings(EuclideanAdaptOptions::default(), 800, 6, 1000.0);
628        vals.adapt_options.mass_matrix_update_freq = 20;
629        vals
630    }
631}
632
633impl Default for FlowNutsSettings {
634    fn default() -> Self {
635        default_nuts_settings(FlowSettings::default(), 1500, 1, 20.0)
636    }
637}
638
/// NUTS chain using diagonal mass-matrix adaptation.
type DiagNutsChain<M> = NutsChain<M, ChaCha8Rng, GlobalStrategy<M, DiagAdaptStrategy<M>>>;
/// NUTS chain using low-rank mass-matrix adaptation.
type LowRankNutsChain<M> = NutsChain<M, ChaCha8Rng, GlobalStrategy<M, LowRankMassMatrixStrategy>>;
641
642fn nuts_options(settings: &NutsSettings<impl Debug + Copy + Default + Serialize>) -> NutsOptions {
643    NutsOptions {
644        maxdepth: settings.maxdepth,
645        mindepth: settings.mindepth,
646        store_divergences: settings.store_divergences,
647        check_turning: settings.check_turning,
648        target_integration_time: settings.target_integration_time,
649        extra_doublings: settings.extra_doublings,
650        max_energy_error: settings.max_energy_error,
651    }
652}
653
654impl Settings for LowRankNutsSettings {
655    type Chain<M: Math> = LowRankNutsChain<M>;
656
657    fn new_chain<M: Math, R: Rng + ?Sized>(
658        &self,
659        chain: u64,
660        mut math: M,
661        mut rng: &mut R,
662    ) -> Self::Chain<M> {
663        let num_tune = self.num_tune;
664        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
665        let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
666        let hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, self.trajectory_kind);
667
668        let options = nuts_options(self);
669
670        let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
671
672        NutsChain::new(
673            math,
674            hamiltonian,
675            strategy,
676            options,
677            rng,
678            chain,
679            self.stats_options(),
680        )
681    }
682
683    fn hint_num_tune(&self) -> usize {
684        usize_hint(self.num_tune, "num_tune")
685    }
686
687    fn hint_num_draws(&self) -> usize {
688        usize_hint(self.num_draws, "num_draws")
689    }
690
691    fn num_chains(&self) -> usize {
692        self.num_chains
693    }
694
695    fn seed(&self) -> u64 {
696        self.seed
697    }
698
699    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
700        StatOptions {
701            adapt: GlobalStrategyStatsOptions {
702                mass_matrix: (),
703                step_size: (),
704            },
705            hamiltonian: -1,
706            point: {
707                let store_gradient = self.store_gradient;
708                let store_unconstrained = self.store_unconstrained;
709                let store_transformed = self.store_transformed;
710                TransformedPointStatsOptions {
711                    store_gradient,
712                    store_unconstrained,
713                    store_transformed,
714                }
715            },
716            divergence: crate::dynamics::DivergenceStatsOptions {
717                store_divergences: self.store_divergences,
718            },
719        }
720    }
721
722    fn sampler_name(&self) -> &'static str {
723        "nuts"
724    }
725
726    fn adaptation_name(&self) -> &'static str {
727        "low_rank"
728    }
729}
730
731impl Settings for DiagNutsSettings {
732    type Chain<M: Math> = DiagNutsChain<M>;
733
734    fn new_chain<M: Math, R: Rng + ?Sized>(
735        &self,
736        chain: u64,
737        mut math: M,
738        mut rng: &mut R,
739    ) -> Self::Chain<M> {
740        let num_tune = self.num_tune;
741        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
742        let mass_matrix = DiagMassMatrix::new(
743            &mut math,
744            self.adapt_options.mass_matrix_options.store_mass_matrix,
745        );
746        let potential = TransformedHamiltonian::new(&mut math, mass_matrix, self.trajectory_kind);
747
748        let options = nuts_options(self);
749
750        let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
751
752        NutsChain::new(
753            math,
754            potential,
755            strategy,
756            options,
757            rng,
758            chain,
759            self.stats_options(),
760        )
761    }
762
763    fn hint_num_tune(&self) -> usize {
764        usize_hint(self.num_tune, "num_tune")
765    }
766
767    fn hint_num_draws(&self) -> usize {
768        usize_hint(self.num_draws, "num_draws")
769    }
770
771    fn num_chains(&self) -> usize {
772        self.num_chains
773    }
774
775    fn seed(&self) -> u64 {
776        self.seed
777    }
778
779    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
780        StatOptions {
781            adapt: GlobalStrategyStatsOptions {
782                mass_matrix: (),
783                step_size: (),
784            },
785            hamiltonian: -1,
786            point: {
787                let store_gradient = self.store_gradient;
788                let store_unconstrained = self.store_unconstrained;
789                let store_transformed = self.store_transformed;
790                TransformedPointStatsOptions {
791                    store_gradient,
792                    store_unconstrained,
793                    store_transformed,
794                }
795            },
796            divergence: crate::dynamics::DivergenceStatsOptions {
797                store_divergences: self.store_divergences,
798            },
799        }
800    }
801
802    fn sampler_name(&self) -> &'static str {
803        "nuts"
804    }
805
806    fn adaptation_name(&self) -> &'static str {
807        "diagonal"
808    }
809}
810
811impl Settings for FlowNutsSettings {
812    type Chain<M: Math> = NutsChain<M, ChaCha8Rng, ExternalTransformAdaptation>;
813
814    fn new_chain<M: Math, R: Rng + ?Sized>(
815        &self,
816        chain: u64,
817        mut math: M,
818        mut rng: &mut R,
819    ) -> Self::Chain<M> {
820        let num_tune = self.num_tune;
821
822        let strategy =
823            ExternalTransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
824        let params = math
825            .new_transformation(rng, math.dim(), chain)
826            .expect("Failed to create external transformation");
827        let transform = ExternalTransformation::new(params);
828        let hamiltonian = TransformedHamiltonian::new(&mut math, transform, self.trajectory_kind);
829
830        let options = nuts_options(self);
831
832        let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
833        NutsChain::new(
834            math,
835            hamiltonian,
836            strategy,
837            options,
838            rng,
839            chain,
840            self.stats_options(),
841        )
842    }
843
844    fn hint_num_tune(&self) -> usize {
845        usize_hint(self.num_tune, "num_tune")
846    }
847
848    fn hint_num_draws(&self) -> usize {
849        usize_hint(self.num_draws, "num_draws")
850    }
851
852    fn num_chains(&self) -> usize {
853        self.num_chains
854    }
855
856    fn seed(&self) -> u64 {
857        self.seed
858    }
859
860    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
861        StatOptions {
862            adapt: (),
863            hamiltonian: (),
864            point: {
865                let store_gradient = self.store_gradient;
866                let store_unconstrained = self.store_unconstrained;
867                let store_transformed = self.store_transformed;
868                TransformedPointStatsOptions {
869                    store_gradient,
870                    store_unconstrained,
871                    store_transformed,
872                }
873            },
874            divergence: crate::dynamics::DivergenceStatsOptions {
875                store_divergences: self.store_divergences,
876            },
877        }
878    }
879
880    fn sampler_name(&self) -> &'static str {
881        "nuts"
882    }
883
884    fn adaptation_name(&self) -> &'static str {
885        "flow"
886    }
887}
888
889impl Settings for FlowMclmcSettings {
890    type Chain<M: Math> = crate::mclmc::MclmcChain<
891        M,
892        ChaCha8Rng,
893        ExternalTransformAdaptation,
894        ExternalTransformation<M>,
895    >;
896
897    fn new_chain<M: Math, R: Rng + ?Sized>(
898        &self,
899        chain: u64,
900        mut math: M,
901        rng: &mut R,
902    ) -> Self::Chain<M> {
903        use crate::dynamics::KineticEnergyKind;
904        use crate::mclmc::MclmcChain;
905
906        let num_tune = self.num_tune;
907        let strategy =
908            ExternalTransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
909        let params = math
910            .new_transformation(rng, math.dim(), chain)
911            .expect("Failed to create external transformation");
912        let transform = ExternalTransformation::new(params);
913        let initial_kind = match self.trajectory_kind {
914            MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
915            MclmcTrajectoryKind::Euclidean
916            | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
917        };
918        let mut hamiltonian = TransformedHamiltonian::new(&mut math, transform, initial_kind);
919        hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
920        let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
921        let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
922        let stats_options = self.stats_options::<M>();
923        MclmcChain::new(
924            math,
925            hamiltonian,
926            strategy,
927            rng,
928            chain,
929            self.subsample_frequency,
930            self.dynamic_step_size,
931            self.trajectory_kind,
932            switch_draw,
933            self.max_energy_error,
934            stats_options,
935        )
936    }
937
938    fn hint_num_tune(&self) -> usize {
939        usize_hint(self.num_tune, "num_tune")
940    }
941
942    fn hint_num_draws(&self) -> usize {
943        usize_hint(self.num_draws, "num_draws")
944    }
945
946    fn num_chains(&self) -> usize {
947        self.num_chains
948    }
949
950    fn seed(&self) -> u64 {
951        self.seed
952    }
953
954    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
955        StatOptions {
956            adapt: (),
957            hamiltonian: (),
958            point: {
959                let store_gradient = self.store_gradient;
960                let store_unconstrained = self.store_unconstrained;
961                let store_transformed = self.store_transformed;
962                TransformedPointStatsOptions {
963                    store_gradient,
964                    store_unconstrained,
965                    store_transformed,
966                }
967            },
968            divergence: crate::dynamics::DivergenceStatsOptions {
969                store_divergences: self.store_divergences,
970            },
971        }
972    }
973
974    fn sampler_name(&self) -> &'static str {
975        "mclmc"
976    }
977
978    fn adaptation_name(&self) -> &'static str {
979        "flow"
980    }
981}
982
983pub fn sample_sequentially<'math, M: Math + 'math, R: Rng + ?Sized>(
984    math: M,
985    settings: DiagNutsSettings,
986    start: &[f64],
987    draws: u64,
988    chain: u64,
989    rng: &mut R,
990) -> Result<impl Iterator<Item = Result<(Box<[f64]>, Progress)>> + 'math> {
991    let mut sampler = settings.new_chain(chain, math, rng);
992    sampler.set_position(start)?;
993    Ok((0..draws).map(move |_| sampler.draw()))
994}
995
/// Snapshot of one chain's sampling progress, as handed to progress
/// callbacks and returned by `Sampler::progress`.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct ChainProgress {
    /// Draws (tuning + sampling) completed so far.
    pub finished_draws: usize,
    /// Total number of draws this chain is expected to produce.
    pub total_draws: usize,
    /// Number of divergent draws observed after tuning ended.
    pub divergences: usize,
    /// Whether the chain is still in its tuning phase.
    pub tuning: bool,
    /// Whether the chain thread has started drawing.
    pub started: bool,
    /// Integration steps used by the most recent draw.
    pub latest_num_steps: usize,
    /// Cumulative integration steps over all draws so far.
    pub total_num_steps: usize,
    /// Step size reported by the most recent draw.
    pub step_size: f64,
    /// Accumulated time spent inside `draw` calls (excludes pauses).
    pub runtime: Duration,
    /// Draw indices (post-tuning) at which divergences occurred.
    pub divergent_draws: Vec<usize>,
}
1010
1011impl ChainProgress {
1012    fn new(total: usize) -> Self {
1013        Self {
1014            finished_draws: 0,
1015            total_draws: total,
1016            divergences: 0,
1017            tuning: true,
1018            started: false,
1019            latest_num_steps: 0,
1020            step_size: 0f64,
1021            total_num_steps: 0,
1022            runtime: Duration::ZERO,
1023            divergent_draws: Vec::new(),
1024        }
1025    }
1026
1027    fn update(&mut self, stats: &Progress, draw_duration: Duration) {
1028        if stats.diverging & !stats.tuning {
1029            self.divergences += 1;
1030            self.divergent_draws.push(self.finished_draws);
1031        }
1032        self.finished_draws += 1;
1033        self.tuning = stats.tuning;
1034
1035        self.latest_num_steps = stats.num_steps as usize;
1036        self.total_num_steps += stats.num_steps as usize;
1037        self.step_size = stats.step_size;
1038        self.runtime += draw_duration;
1039    }
1040}
1041
/// Control messages sent from the controller to a chain worker thread.
enum ChainCommand {
    // Continue sampling; wakes a paused chain.
    Resume,
    // Stop drawing until the next `Resume` arrives.
    Pause,
}
1046
/// Controller-side handle for one chain running on a worker thread.
struct ChainProcess<T>
where
    T: TraceStorage,
{
    // Command channel to the worker; dropping the sender makes the worker
    // stop at its next draw boundary (it sees `Disconnected`).
    stop_marker: Sender<ChainCommand>,
    // Trace shared with the worker; becomes `None` once the controller takes
    // it for finalization, which also tells the worker to stop.
    trace: Arc<Mutex<Option<T::ChainStorage>>>,
    // Progress snapshot, updated by the worker after every draw.
    progress: Arc<Mutex<ChainProgress>>,
}
1055
impl<T: TraceStorage> ChainProcess<T> {
    /// Take the per-chain traces out of every handle and finalize them into
    /// the top-level trace. Chains whose trace was already taken (or never
    /// produced) are skipped.
    fn finalize_many(trace: T, chains: Vec<Self>) -> Result<(Option<anyhow::Error>, T::Finalized)> {
        let finalized_chain_traces = chains
            .into_iter()
            .filter_map(|chain| chain.trace.lock().expect("Poisoned lock").take())
            .map(|chain| chain.finalize())
            .collect_vec();
        trace.finalize(finalized_chain_traces)
    }

    /// Clone the worker's current progress snapshot.
    fn progress(&self) -> ChainProgress {
        self.progress.lock().expect("Poisoned lock").clone()
    }

    /// Ask the worker to continue sampling. Errors if the worker is gone.
    fn resume(&self) -> Result<()> {
        self.stop_marker.send(ChainCommand::Resume)?;
        Ok(())
    }

    /// Ask the worker to pause at the next draw boundary.
    fn pause(&self) -> Result<()> {
        self.stop_marker.send(ChainCommand::Pause)?;
        Ok(())
    }

    /// Spawn one chain onto the rayon scope and return its control handle.
    ///
    /// The worker initializes the model density, retries initial positions,
    /// then draws in a loop, checking the command channel between draws and
    /// recording each draw into the shared trace. Its final `Result` is sent
    /// over `results`.
    fn start<'model, M: Model, S: Settings>(
        model: &'model M,
        chain_trace: T::ChainStorage,
        chain_id: u64,
        seed: u64,
        settings: &'model S,
        scope: &ScopeFifo<'model>,
        results: Sender<Result<()>>,
    ) -> Result<Self> {
        let (stop_marker_tx, stop_marker_rx) = channel();

        // Stream `chain_id + 1`: stream 0 is used by the controller.
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        rng.set_stream(chain_id + 1);

        let chain_trace = Arc::new(Mutex::new(Some(chain_trace)));
        let progress = Arc::new(Mutex::new(ChainProgress::new(
            settings.hint_num_draws() + settings.hint_num_tune(),
        )));

        let trace_inner = chain_trace.clone();
        let progress_inner = progress.clone();

        scope.spawn_fifo(move |_| {
            let chain_trace = trace_inner;
            let progress = progress_inner;

            let mut sample = move || {
                let logp = model
                    .math(&mut rng)
                    .context("Failed to create model density")?;
                let dim = logp.dim();

                let mut sampler = settings.new_chain(chain_id, logp, &mut rng);

                progress.lock().expect("Poisoned mutex").started = true;

                let mut initval = vec![0f64; dim];
                // TODO maxtries: the retry count below is hard-coded.
                // Try up to 500 initial positions until one is accepted.
                let mut error = None;
                for _ in 0..500 {
                    model
                        .init_position(&mut rng, &mut initval)
                        .context("Failed to generate a new initial position")?;
                    if let Err(err) = sampler.set_position(&initval) {
                        error = Some(err);
                        continue;
                    }
                    error = None;
                    break;
                }

                if let Some(error) = error {
                    return Err(error.context("All initialization points failed"));
                }

                let draws = settings.hint_num_tune() + settings.hint_num_draws();

                // Poll the command channel between draws; a pause blocks here
                // until the next command (or disconnect) arrives.
                let mut msg = stop_marker_rx.try_recv();
                let mut draw = 0;
                loop {
                    match msg {
                        // The remote end is dead
                        Err(TryRecvError::Disconnected) => {
                            break;
                        }
                        Err(TryRecvError::Empty) => {}
                        Ok(ChainCommand::Pause) => {
                            // Block until resumed (or the controller is gone).
                            msg = stop_marker_rx.recv().map_err(|e| e.into());
                            continue;
                        }
                        Ok(ChainCommand::Resume) => {}
                    }

                    let now = Instant::now();
                    // NOTE(review): `unwrap` panics the worker thread on a
                    // draw error instead of sending it over `results` —
                    // confirm this is intended.
                    let (_point, mut draw_data, mut stats, info) = sampler.expanded_draw().unwrap();

                    let mut guard = chain_trace
                        .lock()
                        .expect("Could not unlock trace lock. Poisoned mutex");

                    let Some(trace_val) = guard.as_mut() else {
                        // The trace was removed by controller thread. We can stop sampling
                        break;
                    };
                    progress
                        .lock()
                        .expect("Poisoned mutex")
                        .update(&info, now.elapsed());

                    let math = sampler.math();
                    let dims = StatsDims::from(math.deref());
                    trace_val.record_sample(
                        settings,
                        stats.get_all(&dims),
                        draw_data.get_all(math.deref()),
                        &info,
                    )?;

                    draw += 1;
                    if draw == draws {
                        break;
                    }

                    msg = stop_marker_rx.try_recv();
                }
                Ok(())
            };

            let result = sample();

            // We intentionally ignore errors here, because this means some other
            // chain already failed, and should have reported the error.
            let _ = results.send(result);
            drop(results);
        });

        Ok(Self {
            trace: chain_trace,
            stop_marker: stop_marker_tx,
            progress,
        })
    }

    /// Flush the chain's trace to storage, if it has not been taken yet.
    fn flush(&self) -> Result<()> {
        self.trace
            .lock()
            .map_err(|_| anyhow::anyhow!("Could not lock trace mutex"))
            .context("Could not flush trace")?
            .as_mut()
            .map(|v| v.flush())
            .transpose()?;
        Ok(())
    }
}
1214
/// Commands sent from the user-facing `Sampler` handle to the controller
/// thread. Each command is answered with one `SamplerResponse`.
#[derive(Debug)]
enum SamplerCommand {
    // Pause all chains.
    Pause,
    // Resume all chains.
    Continue,
    // Request a per-chain progress snapshot.
    Progress,
    // Flush all chain traces to storage.
    Flush,
    // Build a read-only view of the trace collected so far.
    Inspect,
}
1223
/// Responses from the controller thread, paired one-to-one with the
/// `SamplerCommand` that triggered them.
enum SamplerResponse<T: Send + 'static> {
    // Acknowledgement for Pause/Continue/Flush.
    Ok(),
    // Progress snapshot for every chain.
    Progress(Box<[ChainProgress]>),
    // Result of inspecting the partially collected trace.
    Inspect(T),
}
1229
/// Outcome of [`Sampler::wait_timeout`].
pub enum SamplerWaitResult<F: Send + 'static> {
    /// All chains finished; contains the finalized trace.
    Trace(F),
    /// The timeout elapsed first; the sampler is handed back so waiting
    /// can continue later.
    Timeout(Sampler<F>),
    /// Sampling failed; the trace is included if finalization succeeded.
    Err(anyhow::Error, Option<F>),
}
1235
/// User-facing handle to a set of chains running on a background thread pool.
pub struct Sampler<F: Send + 'static> {
    // Controller thread; joining it yields the finalized trace together
    // with the first chain error, if any.
    main_thread: JoinHandle<Result<(Option<anyhow::Error>, F)>>,
    // Commands to the controller (rendezvous channel, capacity 0).
    commands: SyncSender<SamplerCommand>,
    // Responses matching the commands sent above.
    responses: Receiver<SamplerResponse<(Option<anyhow::Error>, F)>>,
    // Per-chain completion results; one message per finished chain.
    results: Receiver<Result<()>>,
}
1242
/// User-supplied progress reporting hook for [`Sampler`].
pub struct ProgressCallback {
    /// Invoked with the elapsed (pause-adjusted) runtime and a per-chain
    /// progress snapshot.
    pub callback: Box<dyn FnMut(Duration, Box<[ChainProgress]>) + Send>,
    /// Minimum interval between callback invocations.
    pub rate: Duration,
}
1247
1248impl<F: Send + 'static> Sampler<F> {
1249    pub fn new<M, S, C, T>(
1250        model: M,
1251        settings: S,
1252        trace_config: C,
1253        num_cores: usize,
1254        callback: Option<ProgressCallback>,
1255    ) -> Result<Self>
1256    where
1257        S: Settings,
1258        C: StorageConfig<Storage = T>,
1259        M: Model,
1260        T: TraceStorage<Finalized = F>,
1261    {
1262        let (commands_tx, commands_rx) = sync_channel(0);
1263        let (responses_tx, responses_rx) = sync_channel(0);
1264        let (results_tx, results_rx) = channel();
1265
1266        let main_thread = spawn(move || {
1267            let pool = ThreadPoolBuilder::new()
1268                .num_threads(num_cores + 1) // One more thread because the controller also uses one
1269                .thread_name(|i| format!("nutpie-worker-{i}"))
1270                .build()
1271                .context("Could not start thread pool")?;
1272
1273            let settings_ref = &settings;
1274            let model_ref = &model;
1275            let mut callback = callback;
1276
1277            pool.scope_fifo(move |scope| {
1278                let results = results_tx;
1279                let mut chains = Vec::with_capacity(settings.num_chains());
1280
1281                let mut rng = ChaCha8Rng::seed_from_u64(settings.seed());
1282                rng.set_stream(0);
1283
1284                let math = model_ref
1285                    .math(&mut rng)
1286                    .context("Could not create model density")?;
1287                let trace = trace_config
1288                    .new_trace(settings_ref, &math)
1289                    .context("Could not create trace object")?;
1290                drop(math);
1291
1292                for chain_id in 0..settings.num_chains() {
1293                    let chain_trace_val = trace
1294                        .initialize_trace_for_chain(chain_id as u64)
1295                        .context("Failed to create trace object")?;
1296                    let chain = ChainProcess::start(
1297                        model_ref,
1298                        chain_trace_val,
1299                        chain_id as u64,
1300                        settings.seed(),
1301                        settings_ref,
1302                        scope,
1303                        results.clone(),
1304                    );
1305                    chains.push(chain);
1306                }
1307                drop(results);
1308
1309                let (chains, errors): (Vec<_>, Vec<_>) = chains.into_iter().partition_result();
1310                if let Some(error) = errors.into_iter().next() {
1311                    let _ = ChainProcess::finalize_many(trace, chains);
1312                    return Err(error).context("Could not start chains");
1313                }
1314
1315                let mut main_loop = || {
1316                    let start_time = Instant::now();
1317                    let mut pause_start = Instant::now();
1318                    let mut pause_time = Duration::ZERO;
1319
1320                    let mut progress_rate = Duration::MAX;
1321                    if let Some(ProgressCallback { callback, rate }) = &mut callback {
1322                        let progress = chains.iter().map(|chain| chain.progress()).collect_vec();
1323                        callback(start_time.elapsed(), progress.into());
1324                        progress_rate = *rate;
1325                    }
1326                    let mut last_progress = Instant::now();
1327                    let mut is_paused = false;
1328
1329                    loop {
1330                        let timeout = progress_rate.checked_sub(last_progress.elapsed());
1331                        let timeout = timeout.unwrap_or_else(|| {
1332                            if let Some(ProgressCallback { callback, .. }) = &mut callback {
1333                                let progress =
1334                                    chains.iter().map(|chain| chain.progress()).collect_vec();
1335                                let mut elapsed = start_time.elapsed().saturating_sub(pause_time);
1336                                if is_paused {
1337                                    elapsed = elapsed.saturating_sub(pause_start.elapsed());
1338                                }
1339                                callback(elapsed, progress.into());
1340                            }
1341                            last_progress = Instant::now();
1342                            progress_rate
1343                        });
1344
1345                        // TODO return when all chains are done
1346                        match commands_rx.recv_timeout(timeout) {
1347                            Ok(SamplerCommand::Pause) => {
1348                                for chain in chains.iter() {
1349                                    // This failes if the thread is done.
1350                                    // We just want to ignore those threads.
1351                                    let _ = chain.pause();
1352                                }
1353                                if !is_paused {
1354                                    pause_start = Instant::now();
1355                                }
1356                                is_paused = true;
1357                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
1358                                    anyhow::anyhow!(
1359                                        "Could not send pause response to controller thread: {e}"
1360                                    )
1361                                })?;
1362                            }
1363                            Ok(SamplerCommand::Continue) => {
1364                                for chain in chains.iter() {
1365                                    // This failes if the thread is done.
1366                                    // We just want to ignore those threads.
1367                                    let _ = chain.resume();
1368                                }
1369                                pause_time += pause_start.elapsed();
1370                                is_paused = false;
1371                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
1372                                    anyhow::anyhow!(
1373                                        "Could not send continue response to controller thread: {e}"
1374                                    )
1375                                })?;
1376                            }
1377                            Ok(SamplerCommand::Progress) => {
1378                                let progress =
1379                                    chains.iter().map(|chain| chain.progress()).collect_vec();
1380                                responses_tx.send(SamplerResponse::Progress(progress.into())).map_err(|e| {
1381                                    anyhow::anyhow!(
1382                                        "Could not send progress response to controller thread: {e}"
1383                                    )
1384                                })?;
1385                            }
1386                            Ok(SamplerCommand::Inspect) => {
1387                                let traces = chains
1388                                    .iter()
1389                                    .filter_map(|chain| {
1390                                        chain
1391                                            .trace
1392                                            .lock()
1393                                            .expect("Poisoned lock")
1394                                            .as_ref()
1395                                            .map(|v| v.inspect())
1396                                    })
1397                                    .collect_vec();
1398                                let finalized_trace = trace.inspect(traces)?;
1399                                responses_tx.send(SamplerResponse::Inspect(finalized_trace)).map_err(|e| {
1400                                    anyhow::anyhow!(
1401                                        "Could not send inspect response to controller thread: {e}"
1402                                    )
1403                                })?;
1404                            }
1405                            Ok(SamplerCommand::Flush) => {
1406                                for chain in chains.iter() {
1407                                    chain.flush()?;
1408                                }
1409                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
1410                                    anyhow::anyhow!(
1411                                        "Could not send flush response to controller thread: {e}"
1412                                    )
1413                                })?;
1414                            }
1415                            Err(RecvTimeoutError::Timeout) => {}
1416                            Err(RecvTimeoutError::Disconnected) => {
1417                                if let Some(ProgressCallback { callback, .. }) = &mut callback {
1418                                    let progress =
1419                                        chains.iter().map(|chain| chain.progress()).collect_vec();
1420                                    let mut elapsed =
1421                                        start_time.elapsed().saturating_sub(pause_time);
1422                                    if is_paused {
1423                                        elapsed = elapsed.saturating_sub(pause_start.elapsed());
1424                                    }
1425                                    callback(elapsed, progress.into());
1426                                }
1427                                return Ok(());
1428                            }
1429                        };
1430                    }
1431                };
1432                let result: Result<()> = main_loop();
1433                // Run finalization even if something failed
1434                let output = ChainProcess::finalize_many(trace, chains)?;
1435
1436                result?;
1437                Ok(output)
1438            })
1439        });
1440
1441        Ok(Self {
1442            main_thread,
1443            commands: commands_tx,
1444            responses: responses_rx,
1445            results: results_rx,
1446        })
1447    }
1448
1449    pub fn pause(&mut self) -> Result<()> {
1450        self.commands
1451            .send(SamplerCommand::Pause)
1452            .context("Could not send pause command to controller thread")?;
1453        let response = self
1454            .responses
1455            .recv()
1456            .context("Could not recieve pause response from controller thread")?;
1457        let SamplerResponse::Ok() = response else {
1458            bail!("Got invalid response from sample controller thread");
1459        };
1460        Ok(())
1461    }
1462
1463    pub fn resume(&mut self) -> Result<()> {
1464        self.commands.send(SamplerCommand::Continue)?;
1465        let response = self.responses.recv()?;
1466        let SamplerResponse::Ok() = response else {
1467            bail!("Got invalid response from sample controller thread");
1468        };
1469        Ok(())
1470    }
1471
1472    pub fn flush(&mut self) -> Result<()> {
1473        self.commands.send(SamplerCommand::Flush)?;
1474        let response = self
1475            .responses
1476            .recv()
1477            .context("Could not recieve flush response from controller thread")?;
1478        let SamplerResponse::Ok() = response else {
1479            bail!("Got invalid response from sample controller thread");
1480        };
1481        Ok(())
1482    }
1483
1484    pub fn inspect(&mut self) -> Result<(Option<anyhow::Error>, F)> {
1485        self.commands.send(SamplerCommand::Inspect)?;
1486        let response = self
1487            .responses
1488            .recv()
1489            .context("Could not recieve inspect response from controller thread")?;
1490        let SamplerResponse::Inspect(trace) = response else {
1491            bail!("Got invalid response from sample controller thread");
1492        };
1493        Ok(trace)
1494    }
1495
1496    pub fn abort(self) -> Result<(Option<anyhow::Error>, F)> {
1497        drop(self.commands);
1498        let result = self.main_thread.join();
1499        match result {
1500            Err(payload) => std::panic::resume_unwind(payload),
1501            Ok(Ok(val)) => Ok(val),
1502            Ok(Err(err)) => Err(err),
1503        }
1504    }
1505
1506    pub fn wait_timeout(self, timeout: Duration) -> SamplerWaitResult<F> {
1507        let start = Instant::now();
1508        let mut remaining = Some(timeout);
1509        while remaining.is_some() {
1510            match self.results.recv_timeout(timeout) {
1511                Ok(Ok(_)) => remaining = timeout.checked_sub(start.elapsed()),
1512                Ok(Err(e)) => return SamplerWaitResult::Err(e, None),
1513                Err(RecvTimeoutError::Disconnected) => match self.abort() {
1514                    Ok((Some(err), trace)) => return SamplerWaitResult::Err(err, Some(trace)),
1515                    Ok((None, trace)) => return SamplerWaitResult::Trace(trace),
1516                    Err(err) => return SamplerWaitResult::Err(err, None),
1517                },
1518                Err(RecvTimeoutError::Timeout) => break,
1519            }
1520        }
1521        SamplerWaitResult::Timeout(self)
1522    }
1523
1524    pub fn progress(&mut self) -> Result<Box<[ChainProgress]>> {
1525        self.commands.send(SamplerCommand::Progress)?;
1526        let response = self.responses.recv()?;
1527        let SamplerResponse::Progress(progress) = response else {
1528            bail!("Got invalid response from sample controller thread");
1529        };
1530        Ok(progress)
1531    }
1532}
1533
#[cfg(test)]
pub mod test_logps {

    use std::collections::HashMap;

    use crate::math::{CpuLogpFunc, LogpError};
    #[cfg(feature = "zarr")]
    use crate::{Model, math::CpuMath};
    use anyhow::Result;
    use nuts_storable::HasDims;
    #[cfg(feature = "zarr")]
    use rand::Rng;
    use thiserror::Error;

    /// Test density: `dim` independent normal variables, each with mean
    /// `mu` and unit variance.
    #[derive(Clone, Debug)]
    pub struct NormalLogp {
        pub dim: usize,
        pub mu: f64,
    }

    /// Error type for `NormalLogp`; it has no variants and can never occur.
    #[derive(Error, Debug)]
    pub enum NormalLogpError {}

    impl LogpError for NormalLogpError {
        fn is_recoverable(&self) -> bool {
            false
        }
    }

    impl HasDims for &NormalLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            vec![
                ("unconstrained_parameter".to_string(), self.dim as u64),
                ("dim".to_string(), self.dim as u64),
            ]
            .into_iter()
            .collect()
        }
    }

    impl CpuLogpFunc for &NormalLogp {
        type LogpError = NormalLogpError;
        type FlowParameters = ();
        type ExpandedVector = Vec<f64>;

        fn dim(&self) -> usize {
            self.dim
        }

        /// Log-density (up to a constant) of independent `N(mu, 1)`
        /// variables; the gradient `mu - p` is written into `gradient`.
        fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, NormalLogpError> {
            let n = position.len();
            assert!(gradient.len() == n);

            let mut logp = 0f64;
            for (p, g) in position.iter().zip(gradient.iter_mut()) {
                let val = self.mu - p;
                logp -= val * val / 2.;
                *g = val;
            }

            Ok(logp)
        }

        /// Identity expansion: the expanded vector is the position itself.
        fn expand_vector<R>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> std::result::Result<Self::ExpandedVector, crate::math::CpuMathError>
        where
            R: rand::Rng + ?Sized,
        {
            Ok(array.to_vec())
        }

        // The flow-transformation API below is not exercised by these tests.
        // (Fixed misspelled parameter names `_untransofrmed_gradient` /
        // `_untransfogmed_gradient` from the original.)

        fn inv_transform_normalize(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &[f64],
            _untransformed_gradient: &[f64],
            _transformed_position: &mut [f64],
            _transformed_gradient: &mut [f64],
        ) -> std::result::Result<f64, Self::LogpError> {
            unimplemented!()
        }

        fn init_from_untransformed_position(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &[f64],
            _untransformed_gradient: &mut [f64],
            _transformed_position: &mut [f64],
            _transformed_gradient: &mut [f64],
        ) -> std::result::Result<(f64, f64), Self::LogpError> {
            unimplemented!()
        }

        fn init_from_transformed_position(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &mut [f64],
            _untransformed_gradient: &mut [f64],
            _transformed_position: &[f64],
            _transformed_gradient: &mut [f64],
        ) -> std::result::Result<(f64, f64), Self::LogpError> {
            unimplemented!()
        }

        fn update_transformation<'b, R: rand::Rng + ?Sized>(
            &'b mut self,
            _rng: &mut R,
            _untransformed_positions: impl Iterator<Item = &'b [f64]>,
            _untransformed_gradients: impl Iterator<Item = &'b [f64]>,
            _untransformed_logp: impl Iterator<Item = &'b f64>,
            _params: &'b mut Self::FlowParameters,
        ) -> std::result::Result<(), Self::LogpError> {
            unimplemented!()
        }

        fn init_transformation<R: rand::Rng + ?Sized>(
            &mut self,
            _rng: &mut R,
            _untransformed_position: &[f64],
            _untransformed_gradient: &[f64],
            _chain: u64,
        ) -> std::result::Result<Self::FlowParameters, Self::LogpError> {
            unimplemented!()
        }

        fn transformation_id(
            &self,
            _params: &Self::FlowParameters,
        ) -> std::result::Result<i64, Self::LogpError> {
            unimplemented!()
        }
    }

    /// Minimal `Model` wrapper around a CPU log-density for smoke tests.
    #[cfg(feature = "zarr")]
    pub struct CpuModel<F> {
        logp: F,
    }

    #[cfg(feature = "zarr")]
    impl<F> CpuModel<F> {
        pub fn new(logp: F) -> Self {
            Self { logp }
        }
    }

    #[cfg(feature = "zarr")]
    impl<F> Model for CpuModel<F>
    where
        F: Send + Sync + 'static,
        for<'a> &'a F: CpuLogpFunc,
    {
        type Math<'model> = CpuMath<&'model F>;

        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            Ok(CpuMath::new(&self.logp))
        }

        /// Deterministic initialization at the origin.
        fn init_position<R: rand::prelude::Rng + ?Sized>(
            &self,
            _rng: &mut R,
            position: &mut [f64],
        ) -> Result<()> {
            position.iter_mut().for_each(|x| *x = 0.);
            Ok(())
        }
    }
}
1704
1705#[cfg(test)]
1706mod tests {
1707    use super::test_logps::NormalLogp;
1708    use crate::{
1709        Chain, math::CpuMath, sample_sequentially, sampler::DiagMclmcSettings,
1710        sampler::DiagNutsSettings, sampler::LowRankMclmcSettings, sampler::LowRankNutsSettings,
1711        sampler::Settings,
1712    };
1713
1714    #[cfg(feature = "zarr")]
1715    use super::test_logps::CpuModel;
1716
1717    use anyhow::Result;
1718    use itertools::Itertools;
1719    use pretty_assertions::assert_eq;
1720    use rand::{SeedableRng, rngs::StdRng};
1721
1722    #[cfg(feature = "zarr")]
1723    use std::{
1724        sync::Arc,
1725        time::{Duration, Instant},
1726    };
1727
1728    #[cfg(feature = "zarr")]
1729    use crate::{Sampler, ZarrConfig};
1730
1731    #[cfg(feature = "zarr")]
1732    use zarrs::storage::store::MemoryStore;
1733
1734    fn assert_settings_smoke<S: Settings>(settings: S) -> Result<()> {
1735        let logp = NormalLogp { dim: 4, mu: 0.1 };
1736        let math = CpuMath::new(&logp);
1737        let mut rng = StdRng::seed_from_u64(42);
1738
1739        let stat_names = settings.stat_names(&math);
1740        let stat_types = settings.stat_types(&math);
1741        assert!(!stat_names.is_empty());
1742        assert_eq!(stat_names.len(), stat_types.len());
1743
1744        let mut chain = settings.new_chain(0, math, &mut rng);
1745        chain.set_position(&vec![0.2; 4])?;
1746        let (_draw, _info) = chain.draw()?;
1747        Ok(())
1748    }
1749
1750    #[test]
1751    fn all_settings_smoke() -> Result<()> {
1752        assert_settings_smoke(DiagNutsSettings {
1753            num_tune: 10,
1754            num_draws: 10,
1755            ..Default::default()
1756        })?;
1757        assert_settings_smoke(LowRankNutsSettings {
1758            num_tune: 10,
1759            num_draws: 10,
1760            ..Default::default()
1761        })?;
1762        assert_settings_smoke(DiagMclmcSettings {
1763            num_tune: 10,
1764            num_draws: 10,
1765            ..Default::default()
1766        })?;
1767        assert_settings_smoke(LowRankMclmcSettings {
1768            num_tune: 10,
1769            num_draws: 10,
1770            ..Default::default()
1771        })?;
1772        Ok(())
1773    }
1774
1775    #[test]
1776    fn sample_chain() -> Result<()> {
1777        let logp = NormalLogp { dim: 10, mu: 0.1 };
1778        let math = CpuMath::new(&logp);
1779        let settings = DiagNutsSettings {
1780            num_tune: 100,
1781            num_draws: 100,
1782            ..Default::default()
1783        };
1784        let start = vec![0.2; 10];
1785
1786        let mut rng = StdRng::seed_from_u64(42);
1787
1788        let mut chain = settings.new_chain(0, math, &mut rng);
1789
1790        let (_draw, info) = chain.draw()?;
1791        assert!(info.tuning);
1792        assert_eq!(info.draw, 0);
1793
1794        let math = CpuMath::new(&logp);
1795        let chain = sample_sequentially(math, settings, &start, 200, 1, &mut rng).unwrap();
1796        let mut draws = chain.collect_vec();
1797        assert_eq!(draws.len(), 200);
1798
1799        let draw0 = draws.remove(100).unwrap();
1800        let (vals, stats) = draw0;
1801        assert_eq!(vals.len(), 10);
1802        assert_eq!(stats.chain, 1);
1803        assert_eq!(stats.draw, 100);
1804        Ok(())
1805    }
1806
    /// End-to-end run of the parallel `Sampler` against an in-memory zarr
    /// store, covering pause/resume/abort control flow and timeout-based
    /// waiting.
    #[cfg(feature = "zarr")]
    #[test]
    fn sample_parallel() -> Result<()> {
        let logp = NormalLogp { dim: 100, mu: 0.1 };
        let settings = DiagNutsSettings {
            num_tune: 100,
            num_draws: 100,
            seed: 10,
            ..Default::default()
        };

        // Scenario 1: pause twice (a second pause while already paused must
        // not error), resume, then abort and surface any chain error.
        let model = CpuModel::new(logp.clone());
        let store = MemoryStore::new();

        let zarr_config = ZarrConfig::new(Arc::new(store));
        let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
        sampler.pause()?;
        sampler.pause()?;
        // TODO flush trace
        sampler.resume()?;
        let (ok, _) = sampler.abort()?;
        if let Some(err) = ok {
            Err(err)?;
        }

        // Scenario 2: aborting while still paused must also succeed cleanly.
        let store = MemoryStore::new();
        let zarr_config = ZarrConfig::new(Arc::new(store));
        let model = CpuModel::new(logp.clone());
        let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
        sampler.pause()?;
        if let (Some(err), _) = sampler.abort()? {
            Err(err)?;
        }

        // Scenario 3: a near-zero timeout must not be long enough for the
        // run to finish — we expect `Timeout` and to get the sampler back.
        let store = MemoryStore::new();
        let zarr_config = ZarrConfig::new(Arc::new(store));
        let model = CpuModel::new(logp.clone());
        let start = Instant::now();
        let sampler = Sampler::new(model, settings, zarr_config, 4, None)?;

        let mut sampler = match sampler.wait_timeout(Duration::from_nanos(100)) {
            super::SamplerWaitResult::Trace(_) => {
                dbg!(start.elapsed());
                panic!("finished");
            }
            super::SamplerWaitResult::Timeout(sampler) => sampler,
            super::SamplerWaitResult::Err(_, _) => {
                panic!("error")
            }
        };

        // Polling progress while the run is in flight must keep working.
        for _ in 0..30 {
            sampler.progress()?;
        }

        // With a generous timeout the run must complete and yield a trace.
        match sampler.wait_timeout(Duration::from_secs(1)) {
            super::SamplerWaitResult::Trace(_) => {
                dbg!(start.elapsed());
            }
            super::SamplerWaitResult::Timeout(_) => {
                panic!("timeout")
            }
            super::SamplerWaitResult::Err(err, _) => Err(err)?,
        };

        Ok(())
    }
1874
1875    #[test]
1876    fn sample_seq() {
1877        let logp = NormalLogp { dim: 10, mu: 0.1 };
1878        let math = CpuMath::new(&logp);
1879        let settings = DiagNutsSettings {
1880            num_tune: 100,
1881            num_draws: 100,
1882            ..Default::default()
1883        };
1884        let start = vec![0.2; 10];
1885
1886        let mut rng = StdRng::seed_from_u64(42);
1887
1888        let chain = sample_sequentially(math, settings, &start, 200, 1, &mut rng).unwrap();
1889        let mut draws = chain.collect_vec();
1890        assert_eq!(draws.len(), 200);
1891
1892        let draw0 = draws.remove(100).unwrap();
1893        let (vals, stats) = draw0;
1894        assert_eq!(vals.len(), 10);
1895        assert_eq!(stats.chain, 1);
1896        assert_eq!(stats.draw, 100);
1897    }
1898}