nuts_rs/
sampler.rs

1use anyhow::{Context, Result, bail};
2use itertools::Itertools;
3use nuts_storable::{HasDims, Storable, Value};
4use rand::{Rng, SeedableRng, rngs::SmallRng};
5use rand_chacha::ChaCha8Rng;
6use rayon::{ScopeFifo, ThreadPoolBuilder};
7use serde::Serialize;
8use std::{
9    collections::HashMap,
10    fmt::Debug,
11    ops::Deref,
12    sync::{
13        Arc, Mutex,
14        mpsc::{
15            Receiver, RecvTimeoutError, Sender, SyncSender, TryRecvError, channel, sync_channel,
16        },
17    },
18    thread::{JoinHandle, spawn},
19    time::{Duration, Instant},
20};
21
22use crate::{
23    DiagAdaptExpSettings,
24    adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy, GlobalStrategyStatsOptions},
25    chain::{AdaptStrategy, Chain, NutsChain, StatOptions},
26    euclidean_hamiltonian::EuclideanHamiltonian,
27    mass_matrix::DiagMassMatrix,
28    mass_matrix::Strategy as DiagMassMatrixStrategy,
29    mass_matrix::{LowRankMassMatrix, LowRankMassMatrixStrategy, LowRankSettings},
30    math_base::Math,
31    model::Model,
32    nuts::NutsOptions,
33    sampler_stats::{SamplerStats, StatsDims},
34    storage::{ChainStorage, StorageConfig, TraceStorage},
35    transform_adapt_strategy::{TransformAdaptation, TransformedSettings},
36    transformed_hamiltonian::{TransformedHamiltonian, TransformedPointStatsOptions},
37};
38
39/// All sampler configurations implement this trait
40pub trait Settings:
41    private::Sealed + Clone + Copy + Default + Sync + Send + Serialize + 'static
42{
43    type Chain<M: Math>: Chain<M>;
44
45    fn new_chain<M: Math, R: Rng + ?Sized>(
46        &self,
47        chain: u64,
48        math: M,
49        rng: &mut R,
50    ) -> Self::Chain<M>;
51
52    fn hint_num_tune(&self) -> usize;
53    fn hint_num_draws(&self) -> usize;
54    fn num_chains(&self) -> usize;
55    fn seed(&self) -> u64;
56    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions;
57
58    fn stat_names<M: Math>(&self, math: &M) -> Vec<String> {
59        let dims = StatsDims::from(math);
60        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::names(&dims)
61            .into_iter()
62            .map(String::from)
63            .collect()
64    }
65
66    fn data_names<M: Math>(&self, math: &M) -> Vec<String> {
67        <M::ExpandedVector as Storable<_>>::names(math)
68            .into_iter()
69            .map(String::from)
70            .collect()
71    }
72
73    fn stat_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
74        self.stat_names(math)
75            .into_iter()
76            .map(|name| (name.clone(), self.stat_type::<M>(math, &name)))
77            .collect()
78    }
79
80    fn stat_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
81        let dims = StatsDims::from(math);
82        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::item_type(&dims, name)
83    }
84
85    fn data_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
86        self.data_names(math)
87            .into_iter()
88            .map(|name| (name.clone(), self.data_type(math, &name)))
89            .collect()
90    }
91    fn data_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
92        <M::ExpandedVector as Storable<_>>::item_type(math, name)
93    }
94
95    fn stat_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
96        self.stat_names(math)
97            .into_iter()
98            .map(|name| (name.clone(), self.stat_dims::<M>(math, &name)))
99            .collect()
100    }
101
102    fn stat_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
103        let dims = StatsDims::from(math);
104        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::dims(&dims, name)
105            .into_iter()
106            .map(String::from)
107            .collect()
108    }
109
110    fn stat_dim_sizes<M: Math>(&self, math: &M) -> HashMap<String, u64> {
111        let dims = StatsDims::from(math);
112        dims.dim_sizes()
113    }
114
115    fn data_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
116        self.data_names(math)
117            .into_iter()
118            .map(|name| (name.clone(), self.data_dims(math, &name)))
119            .collect()
120    }
121
122    fn data_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
123        <M::ExpandedVector as Storable<_>>::dims(math, name)
124            .into_iter()
125            .map(String::from)
126            .collect()
127    }
128
129    fn stat_coords<M: Math>(&self, math: &M) -> HashMap<String, Value> {
130        let dims = StatsDims::from(math);
131        dims.coords()
132    }
133}
134
/// Summary information about a single draw of a chain.
///
/// `ChainProgress::update` consumes these values to keep the per-chain
/// progress record up to date.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Progress {
    /// Index of the draw (presumably counted per chain — confirm with the chain implementation).
    pub draw: u64,
    /// Id of the chain that produced the draw.
    pub chain: u64,
    /// Whether the draw was flagged as diverging.
    pub diverging: bool,
    /// Whether the draw was flagged as a tuning draw.
    pub tuning: bool,
    /// Step size reported for the draw.
    pub step_size: f64,
    /// Number of steps reported for the draw.
    pub num_steps: u64,
}
145
/// Private module holding the marker trait that seals `Settings`.
///
/// `Settings` requires `private::Sealed` as a supertrait. The module is not
/// public, so only the settings types listed here can implement `Settings`.
mod private {
    use crate::DiagGradNutsSettings;

    use super::{LowRankNutsSettings, TransformedNutsSettings};

    /// Marker trait without methods; implemented for every supported settings type.
    pub trait Sealed {}

    impl Sealed for DiagGradNutsSettings {}

    impl Sealed for LowRankNutsSettings {}

    impl Sealed for TransformedNutsSettings {}
}
159
/// Settings for the NUTS sampler
///
/// The type parameter `A` holds the adaptation options; see the type aliases
/// `DiagGradNutsSettings`, `LowRankNutsSettings` and `TransformedNutsSettings`
/// for the supported choices.
#[derive(Debug, Clone, Copy, Serialize)]
pub struct NutsSettings<A: Debug + Copy + Default + Serialize> {
    /// The number of tuning steps, where we fit the step size and mass matrix.
    pub num_tune: u64,
    /// The number of draws after tuning
    pub num_draws: u64,
    /// The maximum tree depth during sampling. The number of leapfrog steps
    /// is smaller than 2 ^ maxdepth.
    pub maxdepth: u64,
    /// The minimum tree depth during sampling. The number of leapfrog steps
    /// is larger than 2 ^ mindepth.
    pub mindepth: u64,
    /// Store the gradient in the SampleStats
    pub store_gradient: bool,
    /// Store each unconstrained parameter vector in the sampler stats
    pub store_unconstrained: bool,
    /// If the energy error is larger than this threshold we treat the leapfrog
    /// step as a divergence.
    pub max_energy_error: f64,
    /// Store detailed information about each divergence in the sampler stats
    pub store_divergences: bool,
    /// Settings for mass matrix adaptation.
    pub adapt_options: A,
    /// Forwarded unchanged to `NutsOptions::check_turning` when a chain is
    /// created (presumably toggles the turning check of the tree — confirm in `nuts`).
    pub check_turning: bool,

    /// The number of chains the sampler starts.
    pub num_chains: usize,
    /// Seed for the random number generators of the sampler and its chains.
    pub seed: u64,
}
189
/// NUTS settings with Euclidean adaptation options parameterized by `DiagAdaptExpSettings`.
pub type DiagGradNutsSettings = NutsSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;
/// NUTS settings with Euclidean adaptation options parameterized by `LowRankSettings`.
pub type LowRankNutsSettings = NutsSettings<EuclideanAdaptOptions<LowRankSettings>>;
/// NUTS settings that use `TransformedSettings` as adaptation options.
pub type TransformedNutsSettings = NutsSettings<TransformedSettings>;
193
194impl Default for DiagGradNutsSettings {
195    fn default() -> Self {
196        Self {
197            num_tune: 400,
198            num_draws: 1000,
199            maxdepth: 10,
200            mindepth: 0,
201            max_energy_error: 1000f64,
202            store_gradient: false,
203            store_unconstrained: false,
204            store_divergences: false,
205            adapt_options: EuclideanAdaptOptions::default(),
206            check_turning: true,
207            seed: 0,
208            num_chains: 6,
209        }
210    }
211}
212
213impl Default for LowRankNutsSettings {
214    fn default() -> Self {
215        let mut vals = Self {
216            num_tune: 800,
217            num_draws: 1000,
218            maxdepth: 10,
219            mindepth: 0,
220            max_energy_error: 1000f64,
221            store_gradient: false,
222            store_unconstrained: false,
223            store_divergences: false,
224            adapt_options: EuclideanAdaptOptions::default(),
225            check_turning: true,
226            seed: 0,
227            num_chains: 6,
228        };
229        vals.adapt_options.mass_matrix_update_freq = 10;
230        vals
231    }
232}
233
234impl Default for TransformedNutsSettings {
235    fn default() -> Self {
236        Self {
237            num_tune: 1500,
238            num_draws: 1000,
239            maxdepth: 10,
240            mindepth: 0,
241            max_energy_error: 20f64,
242            store_gradient: false,
243            store_unconstrained: false,
244            store_divergences: false,
245            adapt_options: Default::default(),
246            check_turning: true,
247            seed: 0,
248            num_chains: 1,
249        }
250    }
251}
252
/// Chain type created by `DiagGradNutsSettings`: a `NutsChain` driven by a
/// `SmallRng` and a `GlobalStrategy` over `DiagMassMatrixStrategy`.
type DiagGradNutsChain<M> = NutsChain<M, SmallRng, GlobalStrategy<M, DiagMassMatrixStrategy<M>>>;

/// Chain type created by `LowRankNutsSettings`: a `NutsChain` driven by a
/// `SmallRng` and a `GlobalStrategy` over `LowRankMassMatrixStrategy`.
type LowRankNutsChain<M> = NutsChain<M, SmallRng, GlobalStrategy<M, LowRankMassMatrixStrategy>>;

/// Chain type created by `TransformedNutsSettings`: a `NutsChain` driven by a
/// `SmallRng` and the `TransformAdaptation` strategy.
type TransformingNutsChain<M> = NutsChain<M, SmallRng, TransformAdaptation>;
258
259impl Settings for LowRankNutsSettings {
260    type Chain<M: Math> = LowRankNutsChain<M>;
261
262    fn new_chain<M: Math, R: Rng + ?Sized>(
263        &self,
264        chain: u64,
265        mut math: M,
266        mut rng: &mut R,
267    ) -> Self::Chain<M> {
268        let num_tune = self.num_tune;
269        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
270        let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
271        let max_energy_error = self.max_energy_error;
272        let potential = EuclideanHamiltonian::new(&mut math, mass_matrix, max_energy_error, 1f64);
273
274        let options = NutsOptions {
275            maxdepth: self.maxdepth,
276            mindepth: self.mindepth,
277            store_gradient: self.store_gradient,
278            store_divergences: self.store_divergences,
279            store_unconstrained: self.store_unconstrained,
280            check_turning: self.check_turning,
281        };
282
283        let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");
284
285        NutsChain::new(
286            math,
287            potential,
288            strategy,
289            options,
290            rng,
291            chain,
292            self.stats_options(),
293        )
294    }
295
296    fn hint_num_tune(&self) -> usize {
297        self.num_tune as _
298    }
299
300    fn hint_num_draws(&self) -> usize {
301        self.num_draws as _
302    }
303
304    fn num_chains(&self) -> usize {
305        self.num_chains
306    }
307
308    fn seed(&self) -> u64 {
309        self.seed
310    }
311
312    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
313        StatOptions {
314            adapt: GlobalStrategyStatsOptions {
315                mass_matrix: (),
316                step_size: (),
317            },
318            hamiltonian: (),
319            point: (),
320        }
321    }
322}
323
324impl Settings for DiagGradNutsSettings {
325    type Chain<M: Math> = DiagGradNutsChain<M>;
326
327    fn new_chain<M: Math, R: Rng + ?Sized>(
328        &self,
329        chain: u64,
330        mut math: M,
331        mut rng: &mut R,
332    ) -> Self::Chain<M> {
333        let num_tune = self.num_tune;
334        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
335        let mass_matrix = DiagMassMatrix::new(
336            &mut math,
337            self.adapt_options.mass_matrix_options.store_mass_matrix,
338        );
339        let max_energy_error = self.max_energy_error;
340        let potential = EuclideanHamiltonian::new(&mut math, mass_matrix, max_energy_error, 1f64);
341
342        let options = NutsOptions {
343            maxdepth: self.maxdepth,
344            mindepth: self.mindepth,
345            store_gradient: self.store_gradient,
346            store_divergences: self.store_divergences,
347            store_unconstrained: self.store_unconstrained,
348            check_turning: self.check_turning,
349        };
350
351        let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");
352
353        NutsChain::new(
354            math,
355            potential,
356            strategy,
357            options,
358            rng,
359            chain,
360            self.stats_options(),
361        )
362    }
363
364    fn hint_num_tune(&self) -> usize {
365        self.num_tune as _
366    }
367
368    fn hint_num_draws(&self) -> usize {
369        self.num_draws as _
370    }
371
372    fn num_chains(&self) -> usize {
373        self.num_chains
374    }
375
376    fn seed(&self) -> u64 {
377        self.seed
378    }
379
380    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
381        StatOptions {
382            adapt: GlobalStrategyStatsOptions {
383                mass_matrix: (),
384                step_size: (),
385            },
386            hamiltonian: (),
387            point: (),
388        }
389    }
390}
391
392impl Settings for TransformedNutsSettings {
393    type Chain<M: Math> = TransformingNutsChain<M>;
394
395    fn new_chain<M: Math, R: Rng + ?Sized>(
396        &self,
397        chain: u64,
398        mut math: M,
399        mut rng: &mut R,
400    ) -> Self::Chain<M> {
401        let num_tune = self.num_tune;
402        let max_energy_error = self.max_energy_error;
403
404        let strategy = TransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
405        let hamiltonian = TransformedHamiltonian::new(&mut math, max_energy_error);
406
407        let options = NutsOptions {
408            maxdepth: self.maxdepth,
409            mindepth: self.mindepth,
410            store_gradient: self.store_gradient,
411            store_divergences: self.store_divergences,
412            store_unconstrained: self.store_unconstrained,
413            check_turning: self.check_turning,
414        };
415
416        let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");
417        NutsChain::new(
418            math,
419            hamiltonian,
420            strategy,
421            options,
422            rng,
423            chain,
424            self.stats_options(),
425        )
426    }
427
428    fn hint_num_tune(&self) -> usize {
429        self.num_tune as _
430    }
431
432    fn hint_num_draws(&self) -> usize {
433        self.num_draws as _
434    }
435
436    fn num_chains(&self) -> usize {
437        self.num_chains
438    }
439
440    fn seed(&self) -> u64 {
441        self.seed
442    }
443
444    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
445        // TODO make extra config
446        let point = TransformedPointStatsOptions {
447            store_transformed: self.store_unconstrained,
448        };
449        StatOptions {
450            adapt: (),
451            hamiltonian: (),
452            point,
453        }
454    }
455}
456
457pub fn sample_sequentially<'math, M: Math + 'math, R: Rng + ?Sized>(
458    math: M,
459    settings: DiagGradNutsSettings,
460    start: &[f64],
461    draws: u64,
462    chain: u64,
463    rng: &mut R,
464) -> Result<impl Iterator<Item = Result<(Box<[f64]>, Progress)>> + 'math> {
465    let mut sampler = settings.new_chain(chain, math, rng);
466    sampler.set_position(start)?;
467    Ok((0..draws).map(move |_| sampler.draw()))
468}
469
/// Accumulated progress information of a single chain.
///
/// The record is created by `ChainProgress::new` and refreshed after every
/// draw by `ChainProgress::update`.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct ChainProgress {
    /// Number of draws recorded so far (tuning and non-tuning).
    pub finished_draws: usize,
    /// Total number of draws the chain is expected to produce.
    pub total_draws: usize,
    /// Number of diverging draws that were not tuning draws.
    pub divergences: usize,
    /// Tuning flag of the most recent draw; `true` before the first draw.
    pub tuning: bool,
    /// Whether the chain has started; `false` for a new record.
    pub started: bool,
    /// Number of steps of the most recent draw.
    pub latest_num_steps: usize,
    /// Sum of the number of steps over all recorded draws.
    pub total_num_steps: usize,
    /// Step size of the most recent draw.
    pub step_size: f64,
    /// Sum of the durations passed to `update` for all recorded draws.
    pub runtime: Duration,
    /// Value of `finished_draws` at the time of each counted divergence,
    /// i.e. the zero-based index of the diverging draw within the chain.
    pub divergent_draws: Vec<usize>,
}
484
485impl ChainProgress {
486    fn new(total: usize) -> Self {
487        Self {
488            finished_draws: 0,
489            total_draws: total,
490            divergences: 0,
491            tuning: true,
492            started: false,
493            latest_num_steps: 0,
494            step_size: 0f64,
495            total_num_steps: 0,
496            runtime: Duration::ZERO,
497            divergent_draws: Vec::new(),
498        }
499    }
500
501    fn update(&mut self, stats: &Progress, draw_duration: Duration) {
502        if stats.diverging & !stats.tuning {
503            self.divergences += 1;
504            self.divergent_draws.push(self.finished_draws);
505        }
506        self.finished_draws += 1;
507        self.tuning = stats.tuning;
508
509        self.latest_num_steps = stats.num_steps as usize;
510        self.total_num_steps += stats.num_steps as usize;
511        self.step_size = stats.step_size;
512        self.runtime += draw_duration;
513    }
514}
515
/// Message sent from the controller to the worker of a single chain.
enum ChainCommand {
    /// Continue sampling.
    Resume,
    /// Stop drawing and block until the next command arrives.
    Pause,
}
520
/// Controller-side handle of a chain that is sampled by a worker task.
struct ChainProcess<T>
where
    T: TraceStorage,
{
    /// Sending side of the channel that carries `ChainCommand`s to the worker.
    stop_marker: Sender<ChainCommand>,
    /// Storage of the chain, shared with the worker. It is `None` once the
    /// storage has been taken out for finalization.
    trace: Arc<Mutex<Option<T::ChainStorage>>>,
    /// Progress record of the chain, shared with the worker.
    progress: Arc<Mutex<ChainProgress>>,
}
529
530impl<T: TraceStorage> ChainProcess<T> {
531    fn finalize_many(trace: T, chains: Vec<Self>) -> Result<(Option<anyhow::Error>, T::Finalized)> {
532        let finalized_chain_traces = chains
533            .into_iter()
534            .filter_map(|chain| chain.trace.lock().expect("Poisoned lock").take())
535            .map(|chain| chain.finalize())
536            .collect_vec();
537        trace.finalize(finalized_chain_traces)
538    }
539
540    fn progress(&self) -> ChainProgress {
541        self.progress.lock().expect("Poisoned lock").clone()
542    }
543
544    fn resume(&self) -> Result<()> {
545        self.stop_marker.send(ChainCommand::Resume)?;
546        Ok(())
547    }
548
549    fn pause(&self) -> Result<()> {
550        self.stop_marker.send(ChainCommand::Pause)?;
551        Ok(())
552    }
553
554    fn start<'model, M: Model, S: Settings>(
555        model: &'model M,
556        chain_trace: T::ChainStorage,
557        chain_id: u64,
558        seed: u64,
559        settings: &'model S,
560        scope: &ScopeFifo<'model>,
561        results: Sender<Result<()>>,
562    ) -> Result<Self> {
563        let (stop_marker_tx, stop_marker_rx) = channel();
564
565        let mut rng = ChaCha8Rng::seed_from_u64(seed);
566        rng.set_stream(chain_id + 1);
567
568        let chain_trace = Arc::new(Mutex::new(Some(chain_trace)));
569        let progress = Arc::new(Mutex::new(ChainProgress::new(
570            settings.hint_num_draws() + settings.hint_num_tune(),
571        )));
572
573        let trace_inner = chain_trace.clone();
574        let progress_inner = progress.clone();
575
576        scope.spawn_fifo(move |_| {
577            let chain_trace = trace_inner;
578            let progress = progress_inner;
579
580            let mut sample = move || {
581                let logp = model
582                    .math(&mut rng)
583                    .context("Failed to create model density")?;
584                let dim = logp.dim();
585
586                let mut sampler = settings.new_chain(chain_id, logp, &mut rng);
587
588                progress.lock().expect("Poisoned mutex").started = true;
589
590                let mut initval = vec![0f64; dim];
591                // TODO maxtries
592                let mut error = None;
593                for _ in 0..500 {
594                    model
595                        .init_position(&mut rng, &mut initval)
596                        .context("Failed to generate a new initial position")?;
597                    if let Err(err) = sampler.set_position(&initval) {
598                        error = Some(err);
599                        continue;
600                    }
601                    error = None;
602                    break;
603                }
604
605                if let Some(error) = error {
606                    return Err(error.context("All initialization points failed"));
607                }
608
609                let draws = settings.hint_num_tune() + settings.hint_num_draws();
610
611                let mut msg = stop_marker_rx.try_recv();
612                let mut draw = 0;
613                loop {
614                    match msg {
615                        // The remote end is dead
616                        Err(TryRecvError::Disconnected) => {
617                            break;
618                        }
619                        Err(TryRecvError::Empty) => {}
620                        Ok(ChainCommand::Pause) => {
621                            msg = stop_marker_rx.recv().map_err(|e| e.into());
622                            continue;
623                        }
624                        Ok(ChainCommand::Resume) => {}
625                    }
626
627                    let now = Instant::now();
628                    //let (point, info) = sampler.draw().unwrap();
629                    let (_point, mut draw_data, mut stats, info) = sampler.expanded_draw().unwrap();
630
631                    let mut guard = chain_trace
632                        .lock()
633                        .expect("Could not unlock trace lock. Poisoned mutex");
634
635                    let Some(trace_val) = guard.as_mut() else {
636                        // The trace was removed by controller thread. We can stop sampling
637                        break;
638                    };
639                    progress
640                        .lock()
641                        .expect("Poisoned mutex")
642                        .update(&info, now.elapsed());
643
644                    let math = sampler.math();
645                    let dims = StatsDims::from(math.deref());
646                    trace_val.record_sample(
647                        settings,
648                        stats.get_all(&dims),
649                        draw_data.get_all(math.deref()),
650                        &info,
651                    )?;
652
653                    draw += 1;
654                    if draw == draws {
655                        break;
656                    }
657
658                    msg = stop_marker_rx.try_recv();
659                }
660                Ok(())
661            };
662
663            let result = sample();
664
665            // We intentionally ignore errors here, because this means some other
666            // chain already failed, and should have reported the error.
667            let _ = results.send(result);
668            drop(results);
669        });
670
671        Ok(Self {
672            trace: chain_trace,
673            stop_marker: stop_marker_tx,
674            progress,
675        })
676    }
677
678    fn flush(&self) -> Result<()> {
679        self.trace
680            .lock()
681            .map_err(|_| anyhow::anyhow!("Could not lock trace mutex"))
682            .context("Could not flush trace")?
683            .as_mut()
684            .map(|v| v.flush())
685            .transpose()?;
686        Ok(())
687    }
688}
689
/// Command sent to the controller thread that is spawned by `Sampler::new`.
#[derive(Debug)]
enum SamplerCommand {
    /// Pause all chains; answered with `SamplerResponse::Ok`.
    Pause,
    /// Resume all chains; answered with `SamplerResponse::Ok`.
    Continue,
    /// Request the progress of every chain; answered with `SamplerResponse::Progress`.
    Progress,
    /// Flush the storage of every chain; answered with `SamplerResponse::Ok`.
    Flush,
    /// Inspect the current state of the trace; answered with `SamplerResponse::Inspect`.
    Inspect,
}
698
/// Answer of the controller thread to a `SamplerCommand`.
enum SamplerResponse<T: Send + 'static> {
    /// The command was executed and has no payload to return.
    Ok(),
    /// Progress of every chain, in chain order.
    Progress(Box<[ChainProgress]>),
    /// Result of inspecting the trace.
    Inspect(T),
}
704
/// Outcome of waiting for a `Sampler`.
///
/// NOTE(review): the code that produces these values is outside of this view;
/// the variant descriptions below are inferred from the payload types — confirm.
pub enum SamplerWaitResult<F: Send + 'static> {
    /// Carries a finalized trace.
    Trace(F),
    /// Carries the sampler itself, presumably so that the caller can wait again.
    Timeout(Sampler<F>),
    /// Carries an error together with a finalized trace, if one is available.
    Err(anyhow::Error, Option<F>),
}
710
/// Handle of a sampler whose controller runs on a separate thread.
///
/// `F` is the finalized trace type that the sampler produces.
pub struct Sampler<F: Send + 'static> {
    /// Handle of the controller thread; its result is an optional error
    /// together with the finalized trace.
    main_thread: JoinHandle<Result<(Option<anyhow::Error>, F)>>,
    /// Channel for sending `SamplerCommand`s (presumably to the controller thread — confirm).
    commands: SyncSender<SamplerCommand>,
    /// Channel for receiving the `SamplerResponse` to a command.
    responses: Receiver<SamplerResponse<(Option<anyhow::Error>, F)>>,
    /// Channel for receiving per-chain results (presumably the outcome of each chain worker — confirm).
    results: Receiver<Result<()>>,
}
717
/// Callback that is notified about the progress of the chains while sampling.
pub struct ProgressCallback {
    /// Called with the elapsed sampling time and the progress of every chain.
    pub callback: Box<dyn FnMut(Duration, Box<[ChainProgress]>) + Send>,
    /// Interval after which the callback is invoked again.
    pub rate: Duration,
}
722
723impl<F: Send + 'static> Sampler<F> {
724    pub fn new<M, S, C, T>(
725        model: M,
726        settings: S,
727        trace_config: C,
728        num_cores: usize,
729        callback: Option<ProgressCallback>,
730    ) -> Result<Self>
731    where
732        S: Settings,
733        C: StorageConfig<Storage = T>,
734        M: Model,
735        T: TraceStorage<Finalized = F>,
736    {
737        let (commands_tx, commands_rx) = sync_channel(0);
738        let (responses_tx, responses_rx) = sync_channel(0);
739        let (results_tx, results_rx) = channel();
740
741        let main_thread = spawn(move || {
742            let pool = ThreadPoolBuilder::new()
743                .num_threads(num_cores + 1) // One more thread because the controller also uses one
744                .thread_name(|i| format!("nutpie-worker-{i}"))
745                .build()
746                .context("Could not start thread pool")?;
747
748            let settings_ref = &settings;
749            let model_ref = &model;
750            let mut callback = callback;
751
752            pool.scope_fifo(move |scope| {
753                let results = results_tx;
754                let mut chains = Vec::with_capacity(settings.num_chains());
755
756                let mut rng = ChaCha8Rng::seed_from_u64(settings.seed());
757                rng.set_stream(0);
758
759                let math = model_ref
760                    .math(&mut rng)
761                    .context("Could not create model density")?;
762                let trace = trace_config
763                    .new_trace(settings_ref, &math)
764                    .context("Could not create trace object")?;
765                drop(math);
766
767                for chain_id in 0..settings.num_chains() {
768                    let chain_trace_val = trace
769                        .initialize_trace_for_chain(chain_id as u64)
770                        .context("Failed to create trace object")?;
771                    let chain = ChainProcess::start(
772                        model_ref,
773                        chain_trace_val,
774                        chain_id as u64,
775                        settings.seed(),
776                        settings_ref,
777                        scope,
778                        results.clone(),
779                    );
780                    chains.push(chain);
781                }
782                drop(results);
783
784                let (chains, errors): (Vec<_>, Vec<_>) = chains.into_iter().partition_result();
785                if let Some(error) = errors.into_iter().next() {
786                    let _ = ChainProcess::finalize_many(trace, chains);
787                    return Err(error).context("Could not start chains");
788                }
789
790                let mut main_loop = || {
791                    let start_time = Instant::now();
792                    let mut pause_start = Instant::now();
793                    let mut pause_time = Duration::ZERO;
794
795                    let mut progress_rate = Duration::MAX;
796                    if let Some(ProgressCallback { callback, rate }) = &mut callback {
797                        let progress = chains.iter().map(|chain| chain.progress()).collect_vec();
798                        callback(start_time.elapsed(), progress.into());
799                        progress_rate = *rate;
800                    }
801                    let mut last_progress = Instant::now();
802                    let mut is_paused = false;
803
804                    loop {
805                        let timeout = progress_rate.checked_sub(last_progress.elapsed());
806                        let timeout = timeout.unwrap_or_else(|| {
807                            if let Some(ProgressCallback { callback, .. }) = &mut callback {
808                                let progress =
809                                    chains.iter().map(|chain| chain.progress()).collect_vec();
810                                let mut elapsed = start_time.elapsed().saturating_sub(pause_time);
811                                if is_paused {
812                                    elapsed = elapsed.saturating_sub(pause_start.elapsed());
813                                }
814                                callback(elapsed, progress.into());
815                            }
816                            last_progress = Instant::now();
817                            progress_rate
818                        });
819
820                        // TODO return when all chains are done
821                        match commands_rx.recv_timeout(timeout) {
822                            Ok(SamplerCommand::Pause) => {
823                                for chain in chains.iter() {
                                    // This fails if the thread is done.
825                                    // We just want to ignore those threads.
826                                    let _ = chain.pause();
827                                }
828                                if !is_paused {
829                                    pause_start = Instant::now();
830                                }
831                                is_paused = true;
832                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
833                                    anyhow::anyhow!(
834                                        "Could not send pause response to controller thread: {e}"
835                                    )
836                                })?;
837                            }
838                            Ok(SamplerCommand::Continue) => {
839                                for chain in chains.iter() {
                                    // This fails if the thread is done.
841                                    // We just want to ignore those threads.
842                                    let _ = chain.resume();
843                                }
844                                pause_time += pause_start.elapsed();
845                                is_paused = false;
846                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
847                                    anyhow::anyhow!(
848                                        "Could not send continue response to controller thread: {e}"
849                                    )
850                                })?;
851                            }
852                            Ok(SamplerCommand::Progress) => {
853                                let progress =
854                                    chains.iter().map(|chain| chain.progress()).collect_vec();
855                                responses_tx.send(SamplerResponse::Progress(progress.into())).map_err(|e| {
856                                    anyhow::anyhow!(
857                                        "Could not send progress response to controller thread: {e}"
858                                    )
859                                })?;
860                            }
861                            Ok(SamplerCommand::Inspect) => {
862                                let traces = chains
863                                    .iter()
864                                    .map(|chain| {
865                                        chain
866                                            .trace
867                                            .lock()
868                                            .expect("Poisoned lock")
869                                            .as_ref()
870                                            .map(|v| v.inspect())
871                                    })
872                                    .flatten()
873                                    .collect_vec();
874                                let finalized_trace = trace.inspect(traces)?;
875                                responses_tx.send(SamplerResponse::Inspect(finalized_trace)).map_err(|e| {
876                                    anyhow::anyhow!(
877                                        "Could not send inspect response to controller thread: {e}"
878                                    )
879                                })?;
880                            }
881                            Ok(SamplerCommand::Flush) => {
882                                for chain in chains.iter() {
883                                    chain.flush()?;
884                                }
885                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
886                                    anyhow::anyhow!(
887                                        "Could not send flush response to controller thread: {e}"
888                                    )
889                                })?;
890                            }
891                            Err(RecvTimeoutError::Timeout) => {}
892                            Err(RecvTimeoutError::Disconnected) => {
893                                if let Some(ProgressCallback { callback, .. }) = &mut callback {
894                                    let progress =
895                                        chains.iter().map(|chain| chain.progress()).collect_vec();
896                                    let mut elapsed =
897                                        start_time.elapsed().saturating_sub(pause_time);
898                                    if is_paused {
899                                        elapsed = elapsed.saturating_sub(pause_start.elapsed());
900                                    }
901                                    callback(elapsed, progress.into());
902                                }
903                                return Ok(());
904                            }
905                        };
906                    }
907                };
908                let result: Result<()> = main_loop();
909                // Run finalization even if something failed
910                let output = ChainProcess::finalize_many(trace, chains)?;
911
912                result?;
913                Ok(output)
914            })
915        });
916
917        Ok(Self {
918            main_thread,
919            commands: commands_tx,
920            responses: responses_rx,
921            results: results_rx,
922        })
923    }
924
925    pub fn pause(&mut self) -> Result<()> {
926        self.commands
927            .send(SamplerCommand::Pause)
928            .context("Could not send pause command to controller thread")?;
929        let response = self
930            .responses
931            .recv()
932            .context("Could not recieve pause response from controller thread")?;
933        let SamplerResponse::Ok() = response else {
934            bail!("Got invalid response from sample controller thread");
935        };
936        Ok(())
937    }
938
939    pub fn resume(&mut self) -> Result<()> {
940        self.commands.send(SamplerCommand::Continue)?;
941        let response = self.responses.recv()?;
942        let SamplerResponse::Ok() = response else {
943            bail!("Got invalid response from sample controller thread");
944        };
945        Ok(())
946    }
947
948    pub fn flush(&mut self) -> Result<()> {
949        self.commands.send(SamplerCommand::Flush)?;
950        let response = self
951            .responses
952            .recv()
953            .context("Could not recieve flush response from controller thread")?;
954        let SamplerResponse::Ok() = response else {
955            bail!("Got invalid response from sample controller thread");
956        };
957        Ok(())
958    }
959
960    pub fn inspect(&mut self) -> Result<(Option<anyhow::Error>, F)> {
961        self.commands.send(SamplerCommand::Inspect)?;
962        let response = self
963            .responses
964            .recv()
965            .context("Could not recieve inspect response from controller thread")?;
966        let SamplerResponse::Inspect(trace) = response else {
967            bail!("Got invalid response from sample controller thread");
968        };
969        Ok(trace)
970    }
971
972    pub fn abort(self) -> Result<(Option<anyhow::Error>, F)> {
973        drop(self.commands);
974        let result = self.main_thread.join();
975        match result {
976            Err(payload) => std::panic::resume_unwind(payload),
977            Ok(Ok(val)) => Ok(val),
978            Ok(Err(err)) => Err(err),
979        }
980    }
981
982    pub fn wait_timeout(self, timeout: Duration) -> SamplerWaitResult<F> {
983        let start = Instant::now();
984        let mut remaining = Some(timeout);
985        while remaining.is_some() {
986            match self.results.recv_timeout(timeout) {
987                Ok(Ok(_)) => remaining = timeout.checked_sub(start.elapsed()),
988                Ok(Err(e)) => return SamplerWaitResult::Err(e, None),
989                Err(RecvTimeoutError::Disconnected) => match self.abort() {
990                    Ok((Some(err), trace)) => return SamplerWaitResult::Err(err, Some(trace)),
991                    Ok((None, trace)) => return SamplerWaitResult::Trace(trace),
992                    Err(err) => return SamplerWaitResult::Err(err, None),
993                },
994                Err(RecvTimeoutError::Timeout) => break,
995            }
996        }
997        SamplerWaitResult::Timeout(self)
998    }
999
1000    pub fn progress(&mut self) -> Result<Box<[ChainProgress]>> {
1001        self.commands.send(SamplerCommand::Progress)?;
1002        let response = self.responses.recv()?;
1003        let SamplerResponse::Progress(progress) = response else {
1004            bail!("Got invalid response from sample controller thread");
1005        };
1006        Ok(progress)
1007    }
1008}
1009
1010#[cfg(test)]
1011pub mod test_logps {
1012
1013    use std::collections::HashMap;
1014
1015    use crate::{
1016        Model,
1017        cpu_math::{CpuLogpFunc, CpuMath},
1018        math_base::LogpError,
1019    };
1020    use anyhow::Result;
1021    use nuts_storable::HasDims;
1022    use rand::Rng;
1023    use thiserror::Error;
1024
1025    #[derive(Clone, Debug)]
1026    pub struct NormalLogp {
1027        pub dim: usize,
1028        pub mu: f64,
1029    }
1030
1031    #[derive(Error, Debug)]
1032    pub enum NormalLogpError {}
1033
1034    impl LogpError for NormalLogpError {
1035        fn is_recoverable(&self) -> bool {
1036            false
1037        }
1038    }
1039
1040    impl HasDims for &NormalLogp {
1041        fn dim_sizes(&self) -> HashMap<String, u64> {
1042            vec![
1043                ("unconstrained_parameter".to_string(), self.dim as u64),
1044                ("dim".to_string(), self.dim as u64),
1045            ]
1046            .into_iter()
1047            .collect()
1048        }
1049    }
1050
1051    impl CpuLogpFunc for &NormalLogp {
1052        type LogpError = NormalLogpError;
1053        type FlowParameters = ();
1054        type ExpandedVector = Vec<f64>;
1055
1056        fn dim(&self) -> usize {
1057            self.dim
1058        }
1059
1060        fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, NormalLogpError> {
1061            let n = position.len();
1062            assert!(gradient.len() == n);
1063
1064            let mut logp = 0f64;
1065            for (p, g) in position.iter().zip(gradient.iter_mut()) {
1066                let val = self.mu - p;
1067                logp -= val * val / 2.;
1068                *g = val;
1069            }
1070
1071            Ok(logp)
1072        }
1073
1074        fn expand_vector<R>(
1075            &mut self,
1076            _rng: &mut R,
1077            array: &[f64],
1078        ) -> std::result::Result<Self::ExpandedVector, crate::cpu_math::CpuMathError>
1079        where
1080            R: rand::Rng + ?Sized,
1081        {
1082            Ok(array.to_vec())
1083        }
1084
1085        fn inv_transform_normalize(
1086            &mut self,
1087            _params: &Self::FlowParameters,
1088            _untransformed_position: &[f64],
1089            _untransofrmed_gradient: &[f64],
1090            _transformed_position: &mut [f64],
1091            _transformed_gradient: &mut [f64],
1092        ) -> std::result::Result<f64, Self::LogpError> {
1093            unimplemented!()
1094        }
1095
1096        fn init_from_untransformed_position(
1097            &mut self,
1098            _params: &Self::FlowParameters,
1099            _untransformed_position: &[f64],
1100            _untransformed_gradient: &mut [f64],
1101            _transformed_position: &mut [f64],
1102            _transformed_gradient: &mut [f64],
1103        ) -> std::result::Result<(f64, f64), Self::LogpError> {
1104            unimplemented!()
1105        }
1106
1107        fn init_from_transformed_position(
1108            &mut self,
1109            _params: &Self::FlowParameters,
1110            _untransformed_position: &mut [f64],
1111            _untransformed_gradient: &mut [f64],
1112            _transformed_position: &[f64],
1113            _transformed_gradient: &mut [f64],
1114        ) -> std::result::Result<(f64, f64), Self::LogpError> {
1115            unimplemented!()
1116        }
1117
1118        fn update_transformation<'b, R: rand::Rng + ?Sized>(
1119            &'b mut self,
1120            _rng: &mut R,
1121            _untransformed_positions: impl Iterator<Item = &'b [f64]>,
1122            _untransformed_gradients: impl Iterator<Item = &'b [f64]>,
1123            _untransformed_logp: impl Iterator<Item = &'b f64>,
1124            _params: &'b mut Self::FlowParameters,
1125        ) -> std::result::Result<(), Self::LogpError> {
1126            unimplemented!()
1127        }
1128
1129        fn new_transformation<R: rand::Rng + ?Sized>(
1130            &mut self,
1131            _rng: &mut R,
1132            _untransformed_position: &[f64],
1133            _untransfogmed_gradient: &[f64],
1134            _chain: u64,
1135        ) -> std::result::Result<Self::FlowParameters, Self::LogpError> {
1136            unimplemented!()
1137        }
1138
1139        fn transformation_id(
1140            &self,
1141            _params: &Self::FlowParameters,
1142        ) -> std::result::Result<i64, Self::LogpError> {
1143            unimplemented!()
1144        }
1145    }
1146
1147    pub struct CpuModel<F> {
1148        logp: F,
1149    }
1150
1151    impl<F> CpuModel<F> {
1152        pub fn new(logp: F) -> Self {
1153            Self { logp }
1154        }
1155    }
1156
1157    impl<F> Model for CpuModel<F>
1158    where
1159        F: Send + Sync + 'static,
1160        for<'a> &'a F: CpuLogpFunc,
1161    {
1162        type Math<'model> = CpuMath<&'model F>;
1163
1164        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
1165            Ok(CpuMath::new(&self.logp))
1166        }
1167
1168        fn init_position<R: rand::prelude::Rng + ?Sized>(
1169            &self,
1170            _rng: &mut R,
1171            position: &mut [f64],
1172        ) -> Result<()> {
1173            position.iter_mut().for_each(|x| *x = 0.);
1174            Ok(())
1175        }
1176    }
1177}
1178
1179#[cfg(test)]
1180mod tests {
1181    use std::{
1182        sync::Arc,
1183        time::{Duration, Instant},
1184    };
1185
1186    use super::test_logps::NormalLogp;
1187    use crate::{
1188        Chain, DiagGradNutsSettings, Sampler, ZarrConfig,
1189        cpu_math::CpuMath,
1190        sample_sequentially,
1191        sampler::{Settings, test_logps::CpuModel},
1192    };
1193
1194    use anyhow::Result;
1195    use itertools::Itertools;
1196    use pretty_assertions::assert_eq;
1197    use rand::{SeedableRng, rngs::StdRng};
1198    use zarrs::storage::store::MemoryStore;
1199
1200    #[test]
1201    fn sample_chain() -> Result<()> {
1202        let logp = NormalLogp { dim: 10, mu: 0.1 };
1203        let math = CpuMath::new(&logp);
1204        let settings = DiagGradNutsSettings {
1205            num_tune: 100,
1206            num_draws: 100,
1207            ..Default::default()
1208        };
1209        let start = vec![0.2; 10];
1210
1211        let mut rng = StdRng::seed_from_u64(42);
1212
1213        let mut chain = settings.new_chain(0, math, &mut rng);
1214
1215        let (_draw, info) = chain.draw()?;
1216        assert!(info.tuning);
1217        assert_eq!(info.draw, 0);
1218
1219        let math = CpuMath::new(&logp);
1220        let chain = sample_sequentially(math, settings, &start, 200, 1, &mut rng).unwrap();
1221        let mut draws = chain.collect_vec();
1222        assert_eq!(draws.len(), 200);
1223
1224        let draw0 = draws.remove(100).unwrap();
1225        let (vals, stats) = draw0;
1226        assert_eq!(vals.len(), 10);
1227        assert_eq!(stats.chain, 1);
1228        assert_eq!(stats.draw, 100);
1229        Ok(())
1230    }
1231
1232    #[test]
1233    fn sample_parallel() -> Result<()> {
1234        let logp = NormalLogp { dim: 100, mu: 0.1 };
1235        let settings = DiagGradNutsSettings {
1236            num_tune: 100,
1237            num_draws: 100,
1238            seed: 10,
1239            ..Default::default()
1240        };
1241
1242        let model = CpuModel::new(logp.clone());
1243        let store = MemoryStore::new();
1244
1245        let zarr_config = ZarrConfig::new(Arc::new(store));
1246        let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
1247        sampler.pause()?;
1248        sampler.pause()?;
1249        // TODO flush trace
1250        sampler.resume()?;
1251        let (ok, _) = sampler.abort()?;
1252        if let Some(err) = ok {
1253            Err(err)?;
1254        }
1255
1256        let store = MemoryStore::new();
1257        let zarr_config = ZarrConfig::new(Arc::new(store));
1258        let model = CpuModel::new(logp.clone());
1259        let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
1260        sampler.pause()?;
1261        if let (Some(err), _) = sampler.abort()? {
1262            Err(err)?;
1263        }
1264
1265        let store = MemoryStore::new();
1266        let zarr_config = ZarrConfig::new(Arc::new(store));
1267        let model = CpuModel::new(logp.clone());
1268        let start = Instant::now();
1269        let sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
1270
1271        let mut sampler = match sampler.wait_timeout(Duration::from_nanos(100)) {
1272            super::SamplerWaitResult::Trace(_) => {
1273                dbg!(start.elapsed());
1274                panic!("finished");
1275            }
1276            super::SamplerWaitResult::Timeout(sampler) => sampler,
1277            super::SamplerWaitResult::Err(_, _) => {
1278                panic!("error")
1279            }
1280        };
1281
1282        for _ in 0..30 {
1283            sampler.progress()?;
1284        }
1285
1286        match sampler.wait_timeout(Duration::from_secs(1)) {
1287            super::SamplerWaitResult::Trace(_) => {
1288                dbg!(start.elapsed());
1289            }
1290            super::SamplerWaitResult::Timeout(_) => {
1291                panic!("timeout")
1292            }
1293            super::SamplerWaitResult::Err(err, _) => Err(err)?,
1294        };
1295
1296        Ok(())
1297    }
1298
1299    #[test]
1300    fn sample_seq() {
1301        let logp = NormalLogp { dim: 10, mu: 0.1 };
1302        let math = CpuMath::new(&logp);
1303        let settings = DiagGradNutsSettings {
1304            num_tune: 100,
1305            num_draws: 100,
1306            ..Default::default()
1307        };
1308        let start = vec![0.2; 10];
1309
1310        let mut rng = StdRng::seed_from_u64(42);
1311
1312        let chain = sample_sequentially(math, settings, &start, 200, 1, &mut rng).unwrap();
1313        let mut draws = chain.collect_vec();
1314        assert_eq!(draws.len(), 200);
1315
1316        let draw0 = draws.remove(100).unwrap();
1317        let (vals, stats) = draw0;
1318        assert_eq!(vals.len(), 10);
1319        assert_eq!(stats.chain, 1);
1320        assert_eq!(stats.draw, 100);
1321    }
1322}