// nuts_rs/sampler.rs
1use anyhow::{Context, Result, bail};
2use itertools::Itertools;
3use nuts_storable::{HasDims, Storable, Value};
4use rand::{Rng, SeedableRng, rngs::ChaCha8Rng, rngs::SmallRng};
5use rayon::{ScopeFifo, ThreadPoolBuilder};
6use serde::Serialize;
7use std::{
8    collections::HashMap,
9    fmt::Debug,
10    ops::Deref,
11    sync::{
12        Arc, Mutex,
13        mpsc::{
14            Receiver, RecvTimeoutError, Sender, SyncSender, TryRecvError, channel, sync_channel,
15        },
16    },
17    thread::{JoinHandle, spawn},
18    time::{Duration, Instant},
19};
20
21use crate::{
22    DiagAdaptExpSettings,
23    adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy, GlobalStrategyStatsOptions},
24    chain::{AdaptStrategy, Chain, NutsChain, StatOptions},
25    euclidean_hamiltonian::EuclideanHamiltonian,
26    mass_matrix::DiagMassMatrix,
27    mass_matrix::Strategy as DiagMassMatrixStrategy,
28    mass_matrix::{LowRankMassMatrix, LowRankMassMatrixStrategy, LowRankSettings},
29    math_base::Math,
30    model::Model,
31    nuts::NutsOptions,
32    sampler_stats::{SamplerStats, StatsDims},
33    storage::{ChainStorage, StorageConfig, TraceStorage},
34    transform_adapt_strategy::{TransformAdaptation, TransformedSettings},
35    transformed_hamiltonian::{TransformedHamiltonian, TransformedPointStatsOptions},
36};
37
38/// All sampler configurations implement this trait
39pub trait Settings:
40    private::Sealed + Clone + Copy + Default + Sync + Send + Serialize + 'static
41{
42    type Chain<M: Math>: Chain<M>;
43
44    fn new_chain<M: Math, R: Rng + ?Sized>(
45        &self,
46        chain: u64,
47        math: M,
48        rng: &mut R,
49    ) -> Self::Chain<M>;
50
51    fn hint_num_tune(&self) -> usize;
52    fn hint_num_draws(&self) -> usize;
53    fn num_chains(&self) -> usize;
54    fn seed(&self) -> u64;
55    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions;
56
57    fn stat_names<M: Math>(&self, math: &M) -> Vec<String> {
58        let dims = StatsDims::from(math);
59        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::names(&dims)
60            .into_iter()
61            .map(String::from)
62            .collect()
63    }
64
65    fn data_names<M: Math>(&self, math: &M) -> Vec<String> {
66        <M::ExpandedVector as Storable<_>>::names(math)
67            .into_iter()
68            .map(String::from)
69            .collect()
70    }
71
72    fn stat_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
73        self.stat_names(math)
74            .into_iter()
75            .map(|name| (name.clone(), self.stat_type::<M>(math, &name)))
76            .collect()
77    }
78
79    fn stat_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
80        let dims = StatsDims::from(math);
81        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::item_type(&dims, name)
82    }
83
84    fn data_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
85        self.data_names(math)
86            .into_iter()
87            .map(|name| (name.clone(), self.data_type(math, &name)))
88            .collect()
89    }
90    fn data_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
91        <M::ExpandedVector as Storable<_>>::item_type(math, name)
92    }
93
94    fn stat_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
95        self.stat_names(math)
96            .into_iter()
97            .map(|name| (name.clone(), self.stat_dims::<M>(math, &name)))
98            .collect()
99    }
100
101    fn stat_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
102        let dims = StatsDims::from(math);
103        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::dims(&dims, name)
104            .into_iter()
105            .map(String::from)
106            .collect()
107    }
108
109    fn stat_dim_sizes<M: Math>(&self, math: &M) -> HashMap<String, u64> {
110        let dims = StatsDims::from(math);
111        dims.dim_sizes()
112    }
113
114    fn data_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
115        self.data_names(math)
116            .into_iter()
117            .map(|name| (name.clone(), self.data_dims(math, &name)))
118            .collect()
119    }
120
121    fn data_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
122        <M::ExpandedVector as Storable<_>>::dims(math, name)
123            .into_iter()
124            .map(String::from)
125            .collect()
126    }
127
128    fn stat_coords<M: Math>(&self, math: &M) -> HashMap<String, Value> {
129        let dims = StatsDims::from(math);
130        dims.coords()
131    }
132}
133
/// Progress information for one draw of a single chain.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Progress {
    /// Index of the draw within the chain.
    pub draw: u64,
    /// Id of the chain that produced this draw.
    pub chain: u64,
    /// Whether the trajectory of this draw diverged.
    pub diverging: bool,
    /// Whether the chain was still tuning when this draw was made.
    pub tuning: bool,
    /// The current integrator step size.
    pub step_size: f64,
    /// Number of leapfrog steps used for this draw.
    pub num_steps: u64,
}
144
/// Seals the [`Settings`] trait: only the settings types listed here can
/// implement it, so downstream crates cannot add their own implementations.
mod private {
    use crate::DiagGradNutsSettings;

    use super::{LowRankNutsSettings, TransformedNutsSettings};

    pub trait Sealed {}

    impl Sealed for DiagGradNutsSettings {}

    impl Sealed for LowRankNutsSettings {}

    impl Sealed for TransformedNutsSettings {}
}
158
/// Settings for the NUTS sampler
#[derive(Debug, Clone, Copy, Serialize)]
pub struct NutsSettings<A: Debug + Copy + Default + Serialize> {
    /// The number of tuning steps, where we fit the step size and mass matrix.
    pub num_tune: u64,
    /// The number of draws after tuning
    pub num_draws: u64,
    /// The maximum tree depth during sampling. The number of leapfrog steps
    /// is smaller than 2 ^ maxdepth.
    pub maxdepth: u64,
    /// The minimum tree depth during sampling. The number of leapfrog steps
    /// is larger than 2 ^ mindepth.
    pub mindepth: u64,
    /// Store the gradient in the SampleStats
    pub store_gradient: bool,
    /// Store each unconstrained parameter vector in the sampler stats
    pub store_unconstrained: bool,
    /// If the energy error is larger than this threshold we treat the leapfrog
    /// step as a divergence.
    pub max_energy_error: f64,
    /// Store detailed information about each divergence in the sampler stats
    pub store_divergences: bool,
    /// Settings for mass matrix adaptation.
    pub adapt_options: A,
    /// Whether trajectories are checked for U-turns (the NUTS termination
    /// criterion).
    pub check_turning: bool,

    /// The number of chains to sample.
    pub num_chains: usize,
    /// Seed used to derive the random number generators of the chains.
    pub seed: u64,
}
188
/// NUTS settings with diagonal mass matrix adaptation (`DiagAdaptExpSettings`).
pub type DiagGradNutsSettings = NutsSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;
/// NUTS settings with low-rank mass matrix adaptation (`LowRankSettings`).
pub type LowRankNutsSettings = NutsSettings<EuclideanAdaptOptions<LowRankSettings>>;
/// NUTS settings with transformation-based adaptation (`TransformedSettings`).
pub type TransformedNutsSettings = NutsSettings<TransformedSettings>;
192
193impl Default for DiagGradNutsSettings {
194    fn default() -> Self {
195        Self {
196            num_tune: 400,
197            num_draws: 1000,
198            maxdepth: 10,
199            mindepth: 0,
200            max_energy_error: 1000f64,
201            store_gradient: false,
202            store_unconstrained: false,
203            store_divergences: false,
204            adapt_options: EuclideanAdaptOptions::default(),
205            check_turning: true,
206            seed: 0,
207            num_chains: 6,
208        }
209    }
210}
211
212impl Default for LowRankNutsSettings {
213    fn default() -> Self {
214        let mut vals = Self {
215            num_tune: 800,
216            num_draws: 1000,
217            maxdepth: 10,
218            mindepth: 0,
219            max_energy_error: 1000f64,
220            store_gradient: false,
221            store_unconstrained: false,
222            store_divergences: false,
223            adapt_options: EuclideanAdaptOptions::default(),
224            check_turning: true,
225            seed: 0,
226            num_chains: 6,
227        };
228        vals.adapt_options.mass_matrix_update_freq = 10;
229        vals
230    }
231}
232
233impl Default for TransformedNutsSettings {
234    fn default() -> Self {
235        Self {
236            num_tune: 1500,
237            num_draws: 1000,
238            maxdepth: 10,
239            mindepth: 0,
240            max_energy_error: 20f64,
241            store_gradient: false,
242            store_unconstrained: false,
243            store_divergences: false,
244            adapt_options: Default::default(),
245            check_turning: true,
246            seed: 0,
247            num_chains: 1,
248        }
249    }
250}
251
/// Chain type used for sampling with a diagonal mass matrix.
type DiagGradNutsChain<M> = NutsChain<M, SmallRng, GlobalStrategy<M, DiagMassMatrixStrategy<M>>>;

/// Chain type used for sampling with a low-rank mass matrix.
type LowRankNutsChain<M> = NutsChain<M, SmallRng, GlobalStrategy<M, LowRankMassMatrixStrategy>>;

/// Chain type used for sampling with transformation-based adaptation.
type TransformingNutsChain<M> = NutsChain<M, SmallRng, TransformAdaptation>;
257
258impl Settings for LowRankNutsSettings {
259    type Chain<M: Math> = LowRankNutsChain<M>;
260
261    fn new_chain<M: Math, R: Rng + ?Sized>(
262        &self,
263        chain: u64,
264        mut math: M,
265        mut rng: &mut R,
266    ) -> Self::Chain<M> {
267        let num_tune = self.num_tune;
268        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
269        let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
270        let max_energy_error = self.max_energy_error;
271        let potential = EuclideanHamiltonian::new(&mut math, mass_matrix, max_energy_error, 1f64);
272
273        let options = NutsOptions {
274            maxdepth: self.maxdepth,
275            mindepth: self.mindepth,
276            store_gradient: self.store_gradient,
277            store_divergences: self.store_divergences,
278            store_unconstrained: self.store_unconstrained,
279            check_turning: self.check_turning,
280        };
281
282        let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");
283
284        NutsChain::new(
285            math,
286            potential,
287            strategy,
288            options,
289            rng,
290            chain,
291            self.stats_options(),
292        )
293    }
294
295    fn hint_num_tune(&self) -> usize {
296        self.num_tune as _
297    }
298
299    fn hint_num_draws(&self) -> usize {
300        self.num_draws as _
301    }
302
303    fn num_chains(&self) -> usize {
304        self.num_chains
305    }
306
307    fn seed(&self) -> u64 {
308        self.seed
309    }
310
311    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
312        StatOptions {
313            adapt: GlobalStrategyStatsOptions {
314                mass_matrix: (),
315                step_size: (),
316            },
317            hamiltonian: (),
318            point: (),
319        }
320    }
321}
322
323impl Settings for DiagGradNutsSettings {
324    type Chain<M: Math> = DiagGradNutsChain<M>;
325
326    fn new_chain<M: Math, R: Rng + ?Sized>(
327        &self,
328        chain: u64,
329        mut math: M,
330        mut rng: &mut R,
331    ) -> Self::Chain<M> {
332        let num_tune = self.num_tune;
333        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
334        let mass_matrix = DiagMassMatrix::new(
335            &mut math,
336            self.adapt_options.mass_matrix_options.store_mass_matrix,
337        );
338        let max_energy_error = self.max_energy_error;
339        let potential = EuclideanHamiltonian::new(&mut math, mass_matrix, max_energy_error, 1f64);
340
341        let options = NutsOptions {
342            maxdepth: self.maxdepth,
343            mindepth: self.mindepth,
344            store_gradient: self.store_gradient,
345            store_divergences: self.store_divergences,
346            store_unconstrained: self.store_unconstrained,
347            check_turning: self.check_turning,
348        };
349
350        let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");
351
352        NutsChain::new(
353            math,
354            potential,
355            strategy,
356            options,
357            rng,
358            chain,
359            self.stats_options(),
360        )
361    }
362
363    fn hint_num_tune(&self) -> usize {
364        self.num_tune as _
365    }
366
367    fn hint_num_draws(&self) -> usize {
368        self.num_draws as _
369    }
370
371    fn num_chains(&self) -> usize {
372        self.num_chains
373    }
374
375    fn seed(&self) -> u64 {
376        self.seed
377    }
378
379    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
380        StatOptions {
381            adapt: GlobalStrategyStatsOptions {
382                mass_matrix: (),
383                step_size: (),
384            },
385            hamiltonian: (),
386            point: (),
387        }
388    }
389}
390
391impl Settings for TransformedNutsSettings {
392    type Chain<M: Math> = TransformingNutsChain<M>;
393
394    fn new_chain<M: Math, R: Rng + ?Sized>(
395        &self,
396        chain: u64,
397        mut math: M,
398        mut rng: &mut R,
399    ) -> Self::Chain<M> {
400        let num_tune = self.num_tune;
401        let max_energy_error = self.max_energy_error;
402
403        let strategy = TransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
404        let hamiltonian = TransformedHamiltonian::new(&mut math, max_energy_error);
405
406        let options = NutsOptions {
407            maxdepth: self.maxdepth,
408            mindepth: self.mindepth,
409            store_gradient: self.store_gradient,
410            store_divergences: self.store_divergences,
411            store_unconstrained: self.store_unconstrained,
412            check_turning: self.check_turning,
413        };
414
415        let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");
416        NutsChain::new(
417            math,
418            hamiltonian,
419            strategy,
420            options,
421            rng,
422            chain,
423            self.stats_options(),
424        )
425    }
426
427    fn hint_num_tune(&self) -> usize {
428        self.num_tune as _
429    }
430
431    fn hint_num_draws(&self) -> usize {
432        self.num_draws as _
433    }
434
435    fn num_chains(&self) -> usize {
436        self.num_chains
437    }
438
439    fn seed(&self) -> u64 {
440        self.seed
441    }
442
443    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
444        // TODO make extra config
445        let point = TransformedPointStatsOptions {
446            store_transformed: self.store_unconstrained,
447        };
448        StatOptions {
449            adapt: (),
450            hamiltonian: (),
451            point,
452        }
453    }
454}
455
456pub fn sample_sequentially<'math, M: Math + 'math, R: Rng + ?Sized>(
457    math: M,
458    settings: DiagGradNutsSettings,
459    start: &[f64],
460    draws: u64,
461    chain: u64,
462    rng: &mut R,
463) -> Result<impl Iterator<Item = Result<(Box<[f64]>, Progress)>> + 'math> {
464    let mut sampler = settings.new_chain(chain, math, rng);
465    sampler.set_position(start)?;
466    Ok((0..draws).map(move |_| sampler.draw()))
467}
468
/// Accumulated progress information for one chain over the whole run.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct ChainProgress {
    /// Number of draws (tuning and sampling) finished so far.
    pub finished_draws: usize,
    /// Total number of draws this chain is expected to produce.
    pub total_draws: usize,
    /// Number of divergent draws seen after tuning.
    pub divergences: usize,
    /// Whether the chain is still in the tuning phase.
    pub tuning: bool,
    /// Whether the chain worker has started sampling.
    pub started: bool,
    /// Number of leapfrog steps in the most recent draw.
    pub latest_num_steps: usize,
    /// Total number of leapfrog steps across all draws so far.
    pub total_num_steps: usize,
    /// The current integrator step size.
    pub step_size: f64,
    /// Accumulated wall-clock time spent drawing for this chain.
    pub runtime: Duration,
    /// Indices of the draws that diverged (after tuning).
    pub divergent_draws: Vec<usize>,
}
483
484impl ChainProgress {
485    fn new(total: usize) -> Self {
486        Self {
487            finished_draws: 0,
488            total_draws: total,
489            divergences: 0,
490            tuning: true,
491            started: false,
492            latest_num_steps: 0,
493            step_size: 0f64,
494            total_num_steps: 0,
495            runtime: Duration::ZERO,
496            divergent_draws: Vec::new(),
497        }
498    }
499
500    fn update(&mut self, stats: &Progress, draw_duration: Duration) {
501        if stats.diverging & !stats.tuning {
502            self.divergences += 1;
503            self.divergent_draws.push(self.finished_draws);
504        }
505        self.finished_draws += 1;
506        self.tuning = stats.tuning;
507
508        self.latest_num_steps = stats.num_steps as usize;
509        self.total_num_steps += stats.num_steps as usize;
510        self.step_size = stats.step_size;
511        self.runtime += draw_duration;
512    }
513}
514
/// Commands a chain worker accepts from its controller between draws.
enum ChainCommand {
    /// Continue (or keep) sampling.
    Resume,
    /// Stop drawing and block until the next command arrives.
    Pause,
}
519
/// Handle to a single running chain worker.
struct ChainProcess<T>
where
    T: TraceStorage,
{
    /// Command channel to the worker; when it disconnects, the worker's
    /// sampling loop exits.
    stop_marker: Sender<ChainCommand>,
    /// The chain's trace storage; `None` once taken by the controller.
    trace: Arc<Mutex<Option<T::ChainStorage>>>,
    /// Shared progress state, updated by the worker after every draw.
    progress: Arc<Mutex<ChainProgress>>,
}
528
impl<T: TraceStorage> ChainProcess<T> {
    /// Take the per-chain trace out of every handle (skipping chains whose
    /// trace was already removed), finalize each, and finalize the overall
    /// trace with the collected results.
    fn finalize_many(trace: T, chains: Vec<Self>) -> Result<(Option<anyhow::Error>, T::Finalized)> {
        let finalized_chain_traces = chains
            .into_iter()
            .filter_map(|chain| chain.trace.lock().expect("Poisoned lock").take())
            .map(|chain| chain.finalize())
            .collect_vec();
        trace.finalize(finalized_chain_traces)
    }

    /// A snapshot of this chain's current progress.
    fn progress(&self) -> ChainProgress {
        self.progress.lock().expect("Poisoned lock").clone()
    }

    /// Ask the worker to resume sampling. Fails if the worker is gone.
    fn resume(&self) -> Result<()> {
        self.stop_marker.send(ChainCommand::Resume)?;
        Ok(())
    }

    /// Ask the worker to pause after the current draw. Fails if the worker
    /// is gone.
    fn pause(&self) -> Result<()> {
        self.stop_marker.send(ChainCommand::Pause)?;
        Ok(())
    }

    /// Spawn a worker task in `scope` that samples one chain.
    ///
    /// The worker builds the model density and the chain, searches for a
    /// valid initial position (up to 500 attempts), then draws until all
    /// tuning + sampling draws are done, checking for pause/resume commands
    /// between draws. The worker's final `Result` is sent through `results`.
    fn start<'model, M: Model, S: Settings>(
        model: &'model M,
        chain_trace: T::ChainStorage,
        chain_id: u64,
        seed: u64,
        settings: &'model S,
        scope: &ScopeFifo<'model>,
        results: Sender<Result<()>>,
    ) -> Result<Self> {
        let (stop_marker_tx, stop_marker_rx) = channel();

        // Each chain gets its own rng stream derived from the shared seed.
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        rng.set_stream(chain_id + 1);

        // Shared with the worker: the trace slot and the progress state.
        let chain_trace = Arc::new(Mutex::new(Some(chain_trace)));
        let progress = Arc::new(Mutex::new(ChainProgress::new(
            settings.hint_num_draws() + settings.hint_num_tune(),
        )));

        let trace_inner = chain_trace.clone();
        let progress_inner = progress.clone();

        scope.spawn_fifo(move |_| {
            let chain_trace = trace_inner;
            let progress = progress_inner;

            let mut sample = move || {
                let logp = model
                    .math(&mut rng)
                    .context("Failed to create model density")?;
                let dim = logp.dim();

                let mut sampler = settings.new_chain(chain_id, logp, &mut rng);

                progress.lock().expect("Poisoned mutex").started = true;

                // Retry initialization with fresh positions; keep the last
                // error so it can be reported if every attempt fails.
                let mut initval = vec![0f64; dim];
                // TODO maxtries
                let mut error = None;
                for _ in 0..500 {
                    model
                        .init_position(&mut rng, &mut initval)
                        .context("Failed to generate a new initial position")?;
                    if let Err(err) = sampler.set_position(&initval) {
                        error = Some(err);
                        continue;
                    }
                    error = None;
                    break;
                }

                if let Some(error) = error {
                    return Err(error.context("All initialization points failed"));
                }

                let draws = settings.hint_num_tune() + settings.hint_num_draws();

                let mut msg = stop_marker_rx.try_recv();
                let mut draw = 0;
                loop {
                    match msg {
                        // The remote end is dead
                        Err(TryRecvError::Disconnected) => {
                            break;
                        }
                        Err(TryRecvError::Empty) => {}
                        Ok(ChainCommand::Pause) => {
                            // Block until the next command (or hangup).
                            msg = stop_marker_rx.recv().map_err(|e| e.into());
                            continue;
                        }
                        Ok(ChainCommand::Resume) => {}
                    }

                    let now = Instant::now();
                    //let (point, info) = sampler.draw().unwrap();
                    let (_point, mut draw_data, mut stats, info) = sampler.expanded_draw().unwrap();

                    let mut guard = chain_trace
                        .lock()
                        .expect("Could not unlock trace lock. Poisoned mutex");

                    let Some(trace_val) = guard.as_mut() else {
                        // The trace was removed by controller thread. We can stop sampling
                        break;
                    };
                    progress
                        .lock()
                        .expect("Poisoned mutex")
                        .update(&info, now.elapsed());

                    let math = sampler.math();
                    let dims = StatsDims::from(math.deref());
                    trace_val.record_sample(
                        settings,
                        stats.get_all(&dims),
                        draw_data.get_all(math.deref()),
                        &info,
                    )?;

                    draw += 1;
                    if draw == draws {
                        break;
                    }

                    msg = stop_marker_rx.try_recv();
                }
                Ok(())
            };

            let result = sample();

            // We intentionally ignore errors here, because this means some other
            // chain already failed, and should have reported the error.
            let _ = results.send(result);
            drop(results);
        });

        Ok(Self {
            trace: chain_trace,
            stop_marker: stop_marker_tx,
            progress,
        })
    }

    /// Flush the chain's trace storage, if this handle still owns it.
    fn flush(&self) -> Result<()> {
        self.trace
            .lock()
            .map_err(|_| anyhow::anyhow!("Could not lock trace mutex"))
            .context("Could not flush trace")?
            .as_mut()
            .map(|v| v.flush())
            .transpose()?;
        Ok(())
    }
}
688
/// Commands the controller thread accepts from the [`Sampler`] handle.
#[derive(Debug)]
enum SamplerCommand {
    /// Pause all chains.
    Pause,
    /// Resume all paused chains.
    Continue,
    /// Request a progress snapshot for every chain.
    Progress,
    /// Flush each chain's trace storage.
    Flush,
    /// Request an inspection of the partially collected trace.
    Inspect,
}
697
/// Responses the controller thread sends back, one per [`SamplerCommand`].
enum SamplerResponse<T: Send + 'static> {
    /// Plain acknowledgement (for pause/continue/flush).
    Ok(),
    /// Progress snapshots for all chains.
    Progress(Box<[ChainProgress]>),
    /// Result of inspecting the partially collected trace.
    Inspect(T),
}
703
/// Outcome of waiting on a running [`Sampler`].
pub enum SamplerWaitResult<F: Send + 'static> {
    /// Sampling finished and produced a finalized trace.
    Trace(F),
    /// The wait timed out; the still-running sampler is handed back.
    Timeout(Sampler<F>),
    /// Sampling failed; a partial trace may still be available.
    Err(anyhow::Error, Option<F>),
}
709
/// Handle to a sampling run executing on a background controller thread.
pub struct Sampler<F: Send + 'static> {
    /// Controller thread; joining yields the finalized trace and the first
    /// chain error, if any.
    main_thread: JoinHandle<Result<(Option<anyhow::Error>, F)>>,
    /// Command channel to the controller thread.
    commands: SyncSender<SamplerCommand>,
    /// Responses paired with the commands sent above.
    responses: Receiver<SamplerResponse<(Option<anyhow::Error>, F)>>,
    /// Per-chain worker results, one `Result` per chain.
    results: Receiver<Result<()>>,
}
716
/// A callback invoked periodically with progress information for all chains.
pub struct ProgressCallback {
    /// Called with the elapsed sampling time (excluding pauses) and a
    /// progress snapshot per chain.
    pub callback: Box<dyn FnMut(Duration, Box<[ChainProgress]>) + Send>,
    /// Minimum duration between callback invocations.
    pub rate: Duration,
}
721
722impl<F: Send + 'static> Sampler<F> {
723    pub fn new<M, S, C, T>(
724        model: M,
725        settings: S,
726        trace_config: C,
727        num_cores: usize,
728        callback: Option<ProgressCallback>,
729    ) -> Result<Self>
730    where
731        S: Settings,
732        C: StorageConfig<Storage = T>,
733        M: Model,
734        T: TraceStorage<Finalized = F>,
735    {
736        let (commands_tx, commands_rx) = sync_channel(0);
737        let (responses_tx, responses_rx) = sync_channel(0);
738        let (results_tx, results_rx) = channel();
739
740        let main_thread = spawn(move || {
741            let pool = ThreadPoolBuilder::new()
742                .num_threads(num_cores + 1) // One more thread because the controller also uses one
743                .thread_name(|i| format!("nutpie-worker-{i}"))
744                .build()
745                .context("Could not start thread pool")?;
746
747            let settings_ref = &settings;
748            let model_ref = &model;
749            let mut callback = callback;
750
751            pool.scope_fifo(move |scope| {
752                let results = results_tx;
753                let mut chains = Vec::with_capacity(settings.num_chains());
754
755                let mut rng = ChaCha8Rng::seed_from_u64(settings.seed());
756                rng.set_stream(0);
757
758                let math = model_ref
759                    .math(&mut rng)
760                    .context("Could not create model density")?;
761                let trace = trace_config
762                    .new_trace(settings_ref, &math)
763                    .context("Could not create trace object")?;
764                drop(math);
765
766                for chain_id in 0..settings.num_chains() {
767                    let chain_trace_val = trace
768                        .initialize_trace_for_chain(chain_id as u64)
769                        .context("Failed to create trace object")?;
770                    let chain = ChainProcess::start(
771                        model_ref,
772                        chain_trace_val,
773                        chain_id as u64,
774                        settings.seed(),
775                        settings_ref,
776                        scope,
777                        results.clone(),
778                    );
779                    chains.push(chain);
780                }
781                drop(results);
782
783                let (chains, errors): (Vec<_>, Vec<_>) = chains.into_iter().partition_result();
784                if let Some(error) = errors.into_iter().next() {
785                    let _ = ChainProcess::finalize_many(trace, chains);
786                    return Err(error).context("Could not start chains");
787                }
788
789                let mut main_loop = || {
790                    let start_time = Instant::now();
791                    let mut pause_start = Instant::now();
792                    let mut pause_time = Duration::ZERO;
793
794                    let mut progress_rate = Duration::MAX;
795                    if let Some(ProgressCallback { callback, rate }) = &mut callback {
796                        let progress = chains.iter().map(|chain| chain.progress()).collect_vec();
797                        callback(start_time.elapsed(), progress.into());
798                        progress_rate = *rate;
799                    }
800                    let mut last_progress = Instant::now();
801                    let mut is_paused = false;
802
803                    loop {
804                        let timeout = progress_rate.checked_sub(last_progress.elapsed());
805                        let timeout = timeout.unwrap_or_else(|| {
806                            if let Some(ProgressCallback { callback, .. }) = &mut callback {
807                                let progress =
808                                    chains.iter().map(|chain| chain.progress()).collect_vec();
809                                let mut elapsed = start_time.elapsed().saturating_sub(pause_time);
810                                if is_paused {
811                                    elapsed = elapsed.saturating_sub(pause_start.elapsed());
812                                }
813                                callback(elapsed, progress.into());
814                            }
815                            last_progress = Instant::now();
816                            progress_rate
817                        });
818
819                        // TODO return when all chains are done
820                        match commands_rx.recv_timeout(timeout) {
821                            Ok(SamplerCommand::Pause) => {
822                                for chain in chains.iter() {
823                                    // This failes if the thread is done.
824                                    // We just want to ignore those threads.
825                                    let _ = chain.pause();
826                                }
827                                if !is_paused {
828                                    pause_start = Instant::now();
829                                }
830                                is_paused = true;
831                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
832                                    anyhow::anyhow!(
833                                        "Could not send pause response to controller thread: {e}"
834                                    )
835                                })?;
836                            }
837                            Ok(SamplerCommand::Continue) => {
838                                for chain in chains.iter() {
839                                    // This failes if the thread is done.
840                                    // We just want to ignore those threads.
841                                    let _ = chain.resume();
842                                }
843                                pause_time += pause_start.elapsed();
844                                is_paused = false;
845                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
846                                    anyhow::anyhow!(
847                                        "Could not send continue response to controller thread: {e}"
848                                    )
849                                })?;
850                            }
851                            Ok(SamplerCommand::Progress) => {
852                                let progress =
853                                    chains.iter().map(|chain| chain.progress()).collect_vec();
854                                responses_tx.send(SamplerResponse::Progress(progress.into())).map_err(|e| {
855                                    anyhow::anyhow!(
856                                        "Could not send progress response to controller thread: {e}"
857                                    )
858                                })?;
859                            }
860                            Ok(SamplerCommand::Inspect) => {
861                                let traces = chains
862                                    .iter()
863                                    .map(|chain| {
864                                        chain
865                                            .trace
866                                            .lock()
867                                            .expect("Poisoned lock")
868                                            .as_ref()
869                                            .map(|v| v.inspect())
870                                    })
871                                    .flatten()
872                                    .collect_vec();
873                                let finalized_trace = trace.inspect(traces)?;
874                                responses_tx.send(SamplerResponse::Inspect(finalized_trace)).map_err(|e| {
875                                    anyhow::anyhow!(
876                                        "Could not send inspect response to controller thread: {e}"
877                                    )
878                                })?;
879                            }
880                            Ok(SamplerCommand::Flush) => {
881                                for chain in chains.iter() {
882                                    chain.flush()?;
883                                }
884                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
885                                    anyhow::anyhow!(
886                                        "Could not send flush response to controller thread: {e}"
887                                    )
888                                })?;
889                            }
890                            Err(RecvTimeoutError::Timeout) => {}
891                            Err(RecvTimeoutError::Disconnected) => {
892                                if let Some(ProgressCallback { callback, .. }) = &mut callback {
893                                    let progress =
894                                        chains.iter().map(|chain| chain.progress()).collect_vec();
895                                    let mut elapsed =
896                                        start_time.elapsed().saturating_sub(pause_time);
897                                    if is_paused {
898                                        elapsed = elapsed.saturating_sub(pause_start.elapsed());
899                                    }
900                                    callback(elapsed, progress.into());
901                                }
902                                return Ok(());
903                            }
904                        };
905                    }
906                };
907                let result: Result<()> = main_loop();
908                // Run finalization even if something failed
909                let output = ChainProcess::finalize_many(trace, chains)?;
910
911                result?;
912                Ok(output)
913            })
914        });
915
916        Ok(Self {
917            main_thread,
918            commands: commands_tx,
919            responses: responses_rx,
920            results: results_rx,
921        })
922    }
923
924    pub fn pause(&mut self) -> Result<()> {
925        self.commands
926            .send(SamplerCommand::Pause)
927            .context("Could not send pause command to controller thread")?;
928        let response = self
929            .responses
930            .recv()
931            .context("Could not recieve pause response from controller thread")?;
932        let SamplerResponse::Ok() = response else {
933            bail!("Got invalid response from sample controller thread");
934        };
935        Ok(())
936    }
937
938    pub fn resume(&mut self) -> Result<()> {
939        self.commands.send(SamplerCommand::Continue)?;
940        let response = self.responses.recv()?;
941        let SamplerResponse::Ok() = response else {
942            bail!("Got invalid response from sample controller thread");
943        };
944        Ok(())
945    }
946
947    pub fn flush(&mut self) -> Result<()> {
948        self.commands.send(SamplerCommand::Flush)?;
949        let response = self
950            .responses
951            .recv()
952            .context("Could not recieve flush response from controller thread")?;
953        let SamplerResponse::Ok() = response else {
954            bail!("Got invalid response from sample controller thread");
955        };
956        Ok(())
957    }
958
959    pub fn inspect(&mut self) -> Result<(Option<anyhow::Error>, F)> {
960        self.commands.send(SamplerCommand::Inspect)?;
961        let response = self
962            .responses
963            .recv()
964            .context("Could not recieve inspect response from controller thread")?;
965        let SamplerResponse::Inspect(trace) = response else {
966            bail!("Got invalid response from sample controller thread");
967        };
968        Ok(trace)
969    }
970
971    pub fn abort(self) -> Result<(Option<anyhow::Error>, F)> {
972        drop(self.commands);
973        let result = self.main_thread.join();
974        match result {
975            Err(payload) => std::panic::resume_unwind(payload),
976            Ok(Ok(val)) => Ok(val),
977            Ok(Err(err)) => Err(err),
978        }
979    }
980
981    pub fn wait_timeout(self, timeout: Duration) -> SamplerWaitResult<F> {
982        let start = Instant::now();
983        let mut remaining = Some(timeout);
984        while remaining.is_some() {
985            match self.results.recv_timeout(timeout) {
986                Ok(Ok(_)) => remaining = timeout.checked_sub(start.elapsed()),
987                Ok(Err(e)) => return SamplerWaitResult::Err(e, None),
988                Err(RecvTimeoutError::Disconnected) => match self.abort() {
989                    Ok((Some(err), trace)) => return SamplerWaitResult::Err(err, Some(trace)),
990                    Ok((None, trace)) => return SamplerWaitResult::Trace(trace),
991                    Err(err) => return SamplerWaitResult::Err(err, None),
992                },
993                Err(RecvTimeoutError::Timeout) => break,
994            }
995        }
996        SamplerWaitResult::Timeout(self)
997    }
998
999    pub fn progress(&mut self) -> Result<Box<[ChainProgress]>> {
1000        self.commands.send(SamplerCommand::Progress)?;
1001        let response = self.responses.recv()?;
1002        let SamplerResponse::Progress(progress) = response else {
1003            bail!("Got invalid response from sample controller thread");
1004        };
1005        Ok(progress)
1006    }
1007}
1008
#[cfg(test)]
pub mod test_logps {
    //! Small test models used by the sampler tests: an isotropic normal
    //! log-density and a minimal [`Model`] wrapper around it.

    use std::collections::HashMap;

    use crate::{
        Model,
        cpu_math::{CpuLogpFunc, CpuMath},
        math_base::LogpError,
    };
    use anyhow::Result;
    use nuts_storable::HasDims;
    use rand::Rng;
    use thiserror::Error;

    /// Log-density of an isotropic normal with mean `mu` and unit variance
    /// in `dim` dimensions (up to an additive constant).
    #[derive(Clone, Debug)]
    pub struct NormalLogp {
        pub dim: usize,
        pub mu: f64,
    }

    /// The normal log-density never fails, so this error enum has no variants
    /// and can never be constructed.
    #[derive(Error, Debug)]
    pub enum NormalLogpError {}

    impl LogpError for NormalLogpError {
        fn is_recoverable(&self) -> bool {
            // Unreachable in practice: the enum is uninhabited.
            false
        }
    }

    impl HasDims for &NormalLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            // Both dimension names map to the same size for this model.
            HashMap::from([
                ("unconstrained_parameter".to_string(), self.dim as u64),
                ("dim".to_string(), self.dim as u64),
            ])
        }
    }

    impl CpuLogpFunc for &NormalLogp {
        type LogpError = NormalLogpError;
        type FlowParameters = ();
        type ExpandedVector = Vec<f64>;

        fn dim(&self) -> usize {
            self.dim
        }

        /// Evaluate the log-density and write its gradient into `gradient`.
        ///
        /// For each coordinate p: logp contribution is -(mu - p)^2 / 2 and the
        /// gradient is (mu - p).
        fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, NormalLogpError> {
            let n = position.len();
            assert!(gradient.len() == n);

            let mut logp = 0f64;
            for (p, g) in position.iter().zip(gradient.iter_mut()) {
                let val = self.mu - p;
                logp -= val * val / 2.;
                *g = val;
            }

            Ok(logp)
        }

        /// The "expanded" representation is just a copy of the raw draw.
        fn expand_vector<R>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> std::result::Result<Self::ExpandedVector, crate::cpu_math::CpuMathError>
        where
            R: rand::Rng + ?Sized,
        {
            Ok(array.to_vec())
        }

        // The transform-adaptation API below is not exercised by these tests,
        // so all methods are left unimplemented.
        // (Fixed misspelled parameter names: "untransofrmed"/"untransfogmed".)

        fn inv_transform_normalize(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &[f64],
            _untransformed_gradient: &[f64],
            _transformed_position: &mut [f64],
            _transformed_gradient: &mut [f64],
        ) -> std::result::Result<f64, Self::LogpError> {
            unimplemented!()
        }

        fn init_from_untransformed_position(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &[f64],
            _untransformed_gradient: &mut [f64],
            _transformed_position: &mut [f64],
            _transformed_gradient: &mut [f64],
        ) -> std::result::Result<(f64, f64), Self::LogpError> {
            unimplemented!()
        }

        fn init_from_transformed_position(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &mut [f64],
            _untransformed_gradient: &mut [f64],
            _transformed_position: &[f64],
            _transformed_gradient: &mut [f64],
        ) -> std::result::Result<(f64, f64), Self::LogpError> {
            unimplemented!()
        }

        fn update_transformation<'b, R: rand::Rng + ?Sized>(
            &'b mut self,
            _rng: &mut R,
            _untransformed_positions: impl Iterator<Item = &'b [f64]>,
            _untransformed_gradients: impl Iterator<Item = &'b [f64]>,
            _untransformed_logp: impl Iterator<Item = &'b f64>,
            _params: &'b mut Self::FlowParameters,
        ) -> std::result::Result<(), Self::LogpError> {
            unimplemented!()
        }

        fn new_transformation<R: rand::Rng + ?Sized>(
            &mut self,
            _rng: &mut R,
            _untransformed_position: &[f64],
            _untransformed_gradient: &[f64],
            _chain: u64,
        ) -> std::result::Result<Self::FlowParameters, Self::LogpError> {
            unimplemented!()
        }

        fn transformation_id(
            &self,
            _params: &Self::FlowParameters,
        ) -> std::result::Result<i64, Self::LogpError> {
            unimplemented!()
        }
    }

    /// Minimal [`Model`] wrapper around any CPU log-density function.
    pub struct CpuModel<F> {
        logp: F,
    }

    impl<F> CpuModel<F> {
        pub fn new(logp: F) -> Self {
            Self { logp }
        }
    }

    impl<F> Model for CpuModel<F>
    where
        F: Send + Sync + 'static,
        for<'a> &'a F: CpuLogpFunc,
    {
        type Math<'model> = CpuMath<&'model F>;

        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            Ok(CpuMath::new(&self.logp))
        }

        /// Deterministic all-zeros initial position (tests don't need jitter).
        fn init_position<R: rand::prelude::Rng + ?Sized>(
            &self,
            _rng: &mut R,
            position: &mut [f64],
        ) -> Result<()> {
            position.iter_mut().for_each(|x| *x = 0.);
            Ok(())
        }
    }
}
1177
#[cfg(test)]
mod tests {
    use std::{
        sync::Arc,
        time::{Duration, Instant},
    };

    use super::test_logps::NormalLogp;
    use crate::{
        Chain, DiagGradNutsSettings, Sampler, ZarrConfig,
        cpu_math::CpuMath,
        sample_sequentially,
        sampler::{Settings, test_logps::CpuModel},
    };

    use anyhow::Result;
    use itertools::Itertools;
    use pretty_assertions::assert_eq;
    use rand::{SeedableRng, rngs::StdRng};
    use zarrs::storage::store::MemoryStore;

    #[test]
    fn sample_chain() -> Result<()> {
        let logp = NormalLogp { dim: 10, mu: 0.1 };
        let settings = DiagGradNutsSettings {
            num_tune: 100,
            num_draws: 100,
            ..Default::default()
        };
        let init = vec![0.2; 10];
        let mut rng = StdRng::seed_from_u64(42);

        // A fresh chain starts in the tuning phase at draw index 0.
        let mut chain = settings.new_chain(0, CpuMath::new(&logp), &mut rng);
        let (_draw, info) = chain.draw()?;
        assert!(info.tuning);
        assert_eq!(info.draw, 0);

        // Run a full sequential sample and spot-check the first post-tuning draw.
        let iter =
            sample_sequentially(CpuMath::new(&logp), settings, &init, 200, 1, &mut rng).unwrap();
        let mut all_draws = iter.collect_vec();
        assert_eq!(all_draws.len(), 200);

        let (values, stats) = all_draws.remove(100).unwrap();
        assert_eq!(values.len(), 10);
        assert_eq!(stats.chain, 1);
        assert_eq!(stats.draw, 100);
        Ok(())
    }

    #[test]
    fn sample_parallel() -> Result<()> {
        let logp = NormalLogp { dim: 100, mu: 0.1 };
        let settings = DiagGradNutsSettings {
            num_tune: 100,
            num_draws: 100,
            seed: 10,
            ..Default::default()
        };

        // Pausing twice in a row must be harmless; then resume and abort.
        let mut sampler = Sampler::new(
            CpuModel::new(logp.clone()),
            settings,
            ZarrConfig::new(Arc::new(MemoryStore::new())),
            4,
            None,
        )?;
        sampler.pause()?;
        sampler.pause()?;
        // TODO flush trace
        sampler.resume()?;
        let (maybe_err, _) = sampler.abort()?;
        if let Some(err) = maybe_err {
            Err(err)?;
        }

        // Aborting while paused must also shut down cleanly.
        let mut sampler = Sampler::new(
            CpuModel::new(logp.clone()),
            settings,
            ZarrConfig::new(Arc::new(MemoryStore::new())),
            4,
            None,
        )?;
        sampler.pause()?;
        if let (Some(err), _) = sampler.abort()? {
            Err(err)?;
        }

        // A nanosecond-scale deadline should expire before sampling finishes.
        let started = Instant::now();
        let sampler = Sampler::new(
            CpuModel::new(logp.clone()),
            settings,
            ZarrConfig::new(Arc::new(MemoryStore::new())),
            4,
            None,
        )?;

        let mut sampler = match sampler.wait_timeout(Duration::from_nanos(100)) {
            super::SamplerWaitResult::Trace(_) => {
                dbg!(started.elapsed());
                panic!("finished");
            }
            super::SamplerWaitResult::Timeout(sampler) => sampler,
            super::SamplerWaitResult::Err(_, _) => {
                panic!("error")
            }
        };

        // Progress queries must keep working while the sampler runs.
        for _ in 0..30 {
            sampler.progress()?;
        }

        // With a generous deadline the run must complete.
        match sampler.wait_timeout(Duration::from_secs(1)) {
            super::SamplerWaitResult::Trace(_) => {
                dbg!(started.elapsed());
            }
            super::SamplerWaitResult::Timeout(_) => {
                panic!("timeout")
            }
            super::SamplerWaitResult::Err(err, _) => Err(err)?,
        };

        Ok(())
    }

    #[test]
    fn sample_seq() {
        let logp = NormalLogp { dim: 10, mu: 0.1 };
        let settings = DiagGradNutsSettings {
            num_tune: 100,
            num_draws: 100,
            ..Default::default()
        };
        let init = vec![0.2; 10];
        let mut rng = StdRng::seed_from_u64(42);

        // 100 tuning draws followed by 100 regular draws.
        let iter =
            sample_sequentially(CpuMath::new(&logp), settings, &init, 200, 1, &mut rng).unwrap();
        let mut all_draws = iter.collect_vec();
        assert_eq!(all_draws.len(), 200);

        // Spot-check the first non-tuning draw.
        let (values, stats) = all_draws.remove(100).unwrap();
        assert_eq!(values.len(), 10);
        assert_eq!(stats.chain, 1);
        assert_eq!(stats.draw, 100);
    }
}