use anyhow::{Context, Result, bail};
use itertools::Itertools;
use nuts_storable::{HasDims, Storable, Value};
use rand::{Rng, SeedableRng, rngs::SmallRng};
use rand_chacha::ChaCha8Rng;
use rayon::{ScopeFifo, ThreadPoolBuilder};
use serde::Serialize;
use std::{
    collections::HashMap,
    fmt::Debug,
    ops::Deref,
    sync::{
        Arc, Mutex,
        mpsc::{
            Receiver, RecvTimeoutError, Sender, SyncSender, TryRecvError, channel, sync_channel,
        },
    },
    thread::{JoinHandle, spawn},
    time::{Duration, Instant},
};

use crate::{
    DiagAdaptExpSettings,
    adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy, GlobalStrategyStatsOptions},
    chain::{AdaptStrategy, Chain, NutsChain, StatOptions},
    euclidean_hamiltonian::EuclideanHamiltonian,
    mass_matrix::DiagMassMatrix,
    mass_matrix::Strategy as DiagMassMatrixStrategy,
    mass_matrix::{LowRankMassMatrix, LowRankMassMatrixStrategy, LowRankSettings},
    math_base::Math,
    model::Model,
    nuts::NutsOptions,
    sampler_stats::{SamplerStats, StatsDims},
    storage::{ChainStorage, StorageConfig, TraceStorage},
    transform_adapt_strategy::{TransformAdaptation, TransformedSettings},
    transformed_hamiltonian::{TransformedHamiltonian, TransformedPointStatsOptions},
};

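/// Settings that describe how a sampler should be run.
///
/// Implementors pick the concrete chain type and expose the metadata
/// (names, item types, dimensions and coordinates) that storage backends
/// need in order to allocate space for sampler stats and expanded draws.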
pub trait Settings:
    private::Sealed + Clone + Copy + Default + Sync + Send + Serialize + 'static
{
    type Chain<M: Math>: Chain<M>;

    fn new_chain<M: Math, R: Rng + ?Sized>(
        &self,
        chain: u64,
        math: M,
        rng: &mut R,
    ) -> Self::Chain<M>;

    fn hint_num_tune(&self) -> usize;
    fn hint_num_draws(&self) -> usize;
    fn num_chains(&self) -> usize;
    fn seed(&self) -> u64;
    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions;

    fn stat_names<M: Math>(&self, math: &M) -> Vec<String> {
        let dims = StatsDims::from(math);
        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::names(&dims)
            .into_iter()
            .map(String::from)
            .collect()
    }

    fn data_names<M: Math>(&self, math: &M) -> Vec<String> {
        <M::ExpandedVector as Storable<_>>::names(math)
            .into_iter()
            .map(String::from)
            .collect()
    }

    fn stat_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
        self.stat_names(math)
            .into_iter()
            .map(|name| (name.clone(), self.stat_type::<M>(math, &name)))
            .collect()
    }

    fn stat_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
        let dims = StatsDims::from(math);
        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::item_type(&dims, name)
    }

    fn data_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
        self.data_names(math)
            .into_iter()
            .map(|name| (name.clone(), self.data_type(math, &name)))
            .collect()
    }

    fn data_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
        <M::ExpandedVector as Storable<_>>::item_type(math, name)
    }

    fn stat_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
        self.stat_names(math)
            .into_iter()
            .map(|name| (name.clone(), self.stat_dims::<M>(math, &name)))
            .collect()
    }

    fn stat_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
        let dims = StatsDims::from(math);
        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::dims(&dims, name)
            .into_iter()
            .map(String::from)
            .collect()
    }

    fn stat_dim_sizes<M: Math>(&self, math: &M) -> HashMap<String, u64> {
        let dims = StatsDims::from(math);
        dims.dim_sizes()
    }

    fn data_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
        self.data_names(math)
            .into_iter()
            .map(|name| (name.clone(), self.data_dims(math, &name)))
            .collect()
    }

    fn data_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
        <M::ExpandedVector as Storable<_>>::dims(math, name)
            .into_iter()
            .map(String::from)
            .collect()
    }

    fn stat_coords<M: Math>(&self, math: &M) -> HashMap<String, Value> {
        let dims = StatsDims::from(math);
        dims.coords()
    }
}

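/// Per-draw progress information reported by a chain.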
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Progress {
    pub draw: u64,
    pub chain: u64,
    pub diverging: bool,
    pub tuning: bool,
    pub step_size: f64,
    pub num_steps: u64,
}

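// Sealed-trait pattern: `Settings` requires `private::Sealed`, so it can only
// be implemented by the settings types defined in this crate.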
mod private {
    use crate::DiagGradNutsSettings;

    use super::{LowRankNutsSettings, TransformedNutsSettings};

    pub trait Sealed {}

    impl Sealed for DiagGradNutsSettings {}

    impl Sealed for LowRankNutsSettings {}

    impl Sealed for TransformedNutsSettings {}
}

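/// Settings for the NUTS sampler, generic over the adaptation options `A`.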
#[derive(Debug, Clone, Copy, Serialize)]
pub struct NutsSettings<A: Debug + Copy + Default + Serialize> {
    /// The number of tuning (warmup) draws.
    pub num_tune: u64,
    /// The number of draws after tuning.
    pub num_draws: u64,
    /// The maximum tree depth during sampling.
    pub maxdepth: u64,
    /// The minimum tree depth during sampling.
    pub mindepth: u64,
    /// Store the gradient in the sampler stats.
    pub store_gradient: bool,
    /// Store the unconstrained parameters in the sampler stats.
    pub store_unconstrained: bool,
    /// The maximum energy error before a transition is treated as divergent.
    pub max_energy_error: f64,
    /// Store detailed information about each divergence in the sampler stats.
    pub store_divergences: bool,
    /// Settings for the adaptation during tuning.
    pub adapt_options: A,
    /// Whether to check for U-turns while building the trajectory.
    pub check_turning: bool,

    pub num_chains: usize,
    pub seed: u64,
}

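/// Settings for NUTS with a diagonally adapted mass matrix.
///
/// A minimal construction sketch (the overrides are illustrative, not tuned
/// recommendations; see the `Default` impl below for the defaults):
///
/// ```ignore
/// let settings = DiagGradNutsSettings {
///     num_tune: 500,
///     num_draws: 1000,
///     ..Default::default()
/// };
/// ```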
pub type DiagGradNutsSettings = NutsSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;

/// Settings for NUTS with low-rank mass matrix adaptation.
pub type LowRankNutsSettings = NutsSettings<EuclideanAdaptOptions<LowRankSettings>>;

/// Settings for NUTS with transformation-based adaptation.
pub type TransformedNutsSettings = NutsSettings<TransformedSettings>;

impl Default for DiagGradNutsSettings {
    fn default() -> Self {
        Self {
            num_tune: 400,
            num_draws: 1000,
            maxdepth: 10,
            mindepth: 0,
            max_energy_error: 1000f64,
            store_gradient: false,
            store_unconstrained: false,
            store_divergences: false,
            adapt_options: EuclideanAdaptOptions::default(),
            check_turning: true,
            seed: 0,
            num_chains: 6,
        }
    }
}

impl Default for LowRankNutsSettings {
    fn default() -> Self {
        let mut vals = Self {
            num_tune: 800,
            num_draws: 1000,
            maxdepth: 10,
            mindepth: 0,
            max_energy_error: 1000f64,
            store_gradient: false,
            store_unconstrained: false,
            store_divergences: false,
            adapt_options: EuclideanAdaptOptions::default(),
            check_turning: true,
            seed: 0,
            num_chains: 6,
        };
        vals.adapt_options.mass_matrix_update_freq = 10;
        vals
    }
}

impl Default for TransformedNutsSettings {
    fn default() -> Self {
        Self {
            num_tune: 1500,
            num_draws: 1000,
            maxdepth: 10,
            mindepth: 0,
            max_energy_error: 20f64,
            store_gradient: false,
            store_unconstrained: false,
            store_divergences: false,
            adapt_options: Default::default(),
            check_turning: true,
            seed: 0,
            num_chains: 1,
        }
    }
}

type DiagGradNutsChain<M> = NutsChain<M, SmallRng, GlobalStrategy<M, DiagMassMatrixStrategy<M>>>;

type LowRankNutsChain<M> = NutsChain<M, SmallRng, GlobalStrategy<M, LowRankMassMatrixStrategy>>;

type TransformingNutsChain<M> = NutsChain<M, SmallRng, TransformAdaptation>;

impl Settings for LowRankNutsSettings {
    type Chain<M: Math> = LowRankNutsChain<M>;

    fn new_chain<M: Math, R: Rng + ?Sized>(
        &self,
        chain: u64,
        mut math: M,
        mut rng: &mut R,
    ) -> Self::Chain<M> {
        let num_tune = self.num_tune;
        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
        let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
        let max_energy_error = self.max_energy_error;
        let potential = EuclideanHamiltonian::new(&mut math, mass_matrix, max_energy_error, 1f64);

        let options = NutsOptions {
            maxdepth: self.maxdepth,
            mindepth: self.mindepth,
            store_gradient: self.store_gradient,
            store_divergences: self.store_divergences,
            store_unconstrained: self.store_unconstrained,
            check_turning: self.check_turning,
        };

        let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");

        NutsChain::new(
            math,
            potential,
            strategy,
            options,
            rng,
            chain,
            self.stats_options(),
        )
    }

    fn hint_num_tune(&self) -> usize {
        self.num_tune as _
    }

    fn hint_num_draws(&self) -> usize {
        self.num_draws as _
    }

    fn num_chains(&self) -> usize {
        self.num_chains
    }

    fn seed(&self) -> u64 {
        self.seed
    }

    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
        StatOptions {
            adapt: GlobalStrategyStatsOptions {
                mass_matrix: (),
                step_size: (),
            },
            hamiltonian: (),
            point: (),
        }
    }
}

impl Settings for DiagGradNutsSettings {
    type Chain<M: Math> = DiagGradNutsChain<M>;

    fn new_chain<M: Math, R: Rng + ?Sized>(
        &self,
        chain: u64,
        mut math: M,
        mut rng: &mut R,
    ) -> Self::Chain<M> {
        let num_tune = self.num_tune;
        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
        let mass_matrix = DiagMassMatrix::new(
            &mut math,
            self.adapt_options.mass_matrix_options.store_mass_matrix,
        );
        let max_energy_error = self.max_energy_error;
        let potential = EuclideanHamiltonian::new(&mut math, mass_matrix, max_energy_error, 1f64);

        let options = NutsOptions {
            maxdepth: self.maxdepth,
            mindepth: self.mindepth,
            store_gradient: self.store_gradient,
            store_divergences: self.store_divergences,
            store_unconstrained: self.store_unconstrained,
            check_turning: self.check_turning,
        };

        let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");

        NutsChain::new(
            math,
            potential,
            strategy,
            options,
            rng,
            chain,
            self.stats_options(),
        )
    }

    fn hint_num_tune(&self) -> usize {
        self.num_tune as _
    }

    fn hint_num_draws(&self) -> usize {
        self.num_draws as _
    }

    fn num_chains(&self) -> usize {
        self.num_chains
    }

    fn seed(&self) -> u64 {
        self.seed
    }

    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
        StatOptions {
            adapt: GlobalStrategyStatsOptions {
                mass_matrix: (),
                step_size: (),
            },
            hamiltonian: (),
            point: (),
        }
    }
}

impl Settings for TransformedNutsSettings {
    type Chain<M: Math> = TransformingNutsChain<M>;

    fn new_chain<M: Math, R: Rng + ?Sized>(
        &self,
        chain: u64,
        mut math: M,
        mut rng: &mut R,
    ) -> Self::Chain<M> {
        let num_tune = self.num_tune;
        let max_energy_error = self.max_energy_error;

        let strategy = TransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
        let hamiltonian = TransformedHamiltonian::new(&mut math, max_energy_error);

        let options = NutsOptions {
            maxdepth: self.maxdepth,
            mindepth: self.mindepth,
            store_gradient: self.store_gradient,
            store_divergences: self.store_divergences,
            store_unconstrained: self.store_unconstrained,
            check_turning: self.check_turning,
        };

        let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");
        NutsChain::new(
            math,
            hamiltonian,
            strategy,
            options,
            rng,
            chain,
            self.stats_options(),
        )
    }

    fn hint_num_tune(&self) -> usize {
        self.num_tune as _
    }

    fn hint_num_draws(&self) -> usize {
        self.num_draws as _
    }

    fn num_chains(&self) -> usize {
        self.num_chains
    }

    fn seed(&self) -> u64 {
        self.seed
    }

    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
        let point = TransformedPointStatsOptions {
            store_transformed: self.store_unconstrained,
        };
        StatOptions {
            adapt: (),
            hamiltonian: (),
            point,
        }
    }
}

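/// Run a single NUTS chain and return a lazy iterator over its draws.
///
/// Each item is a draw (as a boxed slice) together with a [`Progress`]
/// summary. A usage sketch mirroring the tests below (`NormalLogp`, `CpuMath`
/// and `start` come from that test code):
///
/// ```ignore
/// let math = CpuMath::new(&logp);
/// let settings = DiagGradNutsSettings::default();
/// let mut rng = StdRng::seed_from_u64(42);
/// for draw in sample_sequentially(math, settings, &start, 200, 1, &mut rng)? {
///     let (position, progress) = draw?;
/// }
/// ```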
pub fn sample_sequentially<'math, M: Math + 'math, R: Rng + ?Sized>(
    math: M,
    settings: DiagGradNutsSettings,
    start: &[f64],
    draws: u64,
    chain: u64,
    rng: &mut R,
) -> Result<impl Iterator<Item = Result<(Box<[f64]>, Progress)>> + 'math> {
    let mut sampler = settings.new_chain(chain, math, rng);
    sampler.set_position(start)?;
    Ok((0..draws).map(move |_| sampler.draw()))
}

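/// Accumulated progress of a single chain, as reported by [`Sampler::progress`].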
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct ChainProgress {
    pub finished_draws: usize,
    pub total_draws: usize,
    pub divergences: usize,
    pub tuning: bool,
    pub started: bool,
    pub latest_num_steps: usize,
    pub total_num_steps: usize,
    pub step_size: f64,
    pub runtime: Duration,
    pub divergent_draws: Vec<usize>,
}

impl ChainProgress {
    fn new(total: usize) -> Self {
        Self {
            finished_draws: 0,
            total_draws: total,
            divergences: 0,
            tuning: true,
            started: false,
            latest_num_steps: 0,
            step_size: 0f64,
            total_num_steps: 0,
            runtime: Duration::ZERO,
            divergent_draws: Vec::new(),
        }
    }

    fn update(&mut self, stats: &Progress, draw_duration: Duration) {
        if stats.diverging && !stats.tuning {
            self.divergences += 1;
            self.divergent_draws.push(self.finished_draws);
        }
        self.finished_draws += 1;
        self.tuning = stats.tuning;

        self.latest_num_steps = stats.num_steps as usize;
        self.total_num_steps += stats.num_steps as usize;
        self.step_size = stats.step_size;
        self.runtime += draw_duration;
    }
}

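/// Commands sent from the controller loop to an individual chain worker.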
enum ChainCommand {
    Resume,
    Pause,
}

struct ChainProcess<T>
where
    T: TraceStorage,
{
    stop_marker: Sender<ChainCommand>,
    trace: Arc<Mutex<Option<T::ChainStorage>>>,
    progress: Arc<Mutex<ChainProgress>>,
}

impl<T: TraceStorage> ChainProcess<T> {
    fn finalize_many(trace: T, chains: Vec<Self>) -> Result<(Option<anyhow::Error>, T::Finalized)> {
        let finalized_chain_traces = chains
            .into_iter()
            .filter_map(|chain| chain.trace.lock().expect("Poisoned lock").take())
            .map(|chain_trace| chain_trace.finalize())
            .collect_vec();
        trace.finalize(finalized_chain_traces)
    }

    fn progress(&self) -> ChainProgress {
        self.progress.lock().expect("Poisoned lock").clone()
    }

    fn resume(&self) -> Result<()> {
        self.stop_marker.send(ChainCommand::Resume)?;
        Ok(())
    }

    fn pause(&self) -> Result<()> {
        self.stop_marker.send(ChainCommand::Pause)?;
        Ok(())
    }

    fn start<'model, M: Model, S: Settings>(
        model: &'model M,
        chain_trace: T::ChainStorage,
        chain_id: u64,
        seed: u64,
        settings: &'model S,
        scope: &ScopeFifo<'model>,
        results: Sender<Result<()>>,
    ) -> Result<Self> {
        let (stop_marker_tx, stop_marker_rx) = channel();

        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        rng.set_stream(chain_id + 1);

        let chain_trace = Arc::new(Mutex::new(Some(chain_trace)));
        let progress = Arc::new(Mutex::new(ChainProgress::new(
            settings.hint_num_draws() + settings.hint_num_tune(),
        )));

        let trace_inner = chain_trace.clone();
        let progress_inner = progress.clone();

        scope.spawn_fifo(move |_| {
            let chain_trace = trace_inner;
            let progress = progress_inner;

            let mut sample = move || {
                let logp = model
                    .math(&mut rng)
                    .context("Failed to create model density")?;
                let dim = logp.dim();

                let mut sampler = settings.new_chain(chain_id, logp, &mut rng);

                progress.lock().expect("Poisoned mutex").started = true;

                // Try up to 500 initial positions before giving up.
                let mut initval = vec![0f64; dim];
                let mut error = None;
                for _ in 0..500 {
                    model
                        .init_position(&mut rng, &mut initval)
                        .context("Failed to generate a new initial position")?;
                    if let Err(err) = sampler.set_position(&initval) {
                        error = Some(err);
                        continue;
                    }
                    error = None;
                    break;
                }

                if let Some(error) = error {
                    return Err(error.context("All initialization points failed"));
                }

                let draws = settings.hint_num_tune() + settings.hint_num_draws();

                let mut msg = stop_marker_rx.try_recv();
                let mut draw = 0;
                loop {
                    match msg {
                        // The controller hung up; stop sampling.
                        Err(TryRecvError::Disconnected) => {
                            break;
                        }
                        Err(TryRecvError::Empty) => {}
                        // Block until we are asked to resume.
                        Ok(ChainCommand::Pause) => {
                            msg = stop_marker_rx.recv().map_err(|e| e.into());
                            continue;
                        }
                        Ok(ChainCommand::Resume) => {}
                    }

                    let now = Instant::now();
                    let (_point, mut draw_data, mut stats, info) = sampler
                        .expanded_draw()
                        .context("Failed to compute next draw")?;

                    let mut guard = chain_trace
                        .lock()
                        .expect("Could not unlock trace lock. Poisoned mutex");

                    // If the trace was already taken for finalization, stop.
                    let Some(trace_val) = guard.as_mut() else {
                        break;
                    };
                    progress
                        .lock()
                        .expect("Poisoned mutex")
                        .update(&info, now.elapsed());

                    let math = sampler.math();
                    let dims = StatsDims::from(math.deref());
                    trace_val.record_sample(
                        settings,
                        stats.get_all(&dims),
                        draw_data.get_all(math.deref()),
                        &info,
                    )?;

                    draw += 1;
                    if draw == draws {
                        break;
                    }

                    msg = stop_marker_rx.try_recv();
                }
                Ok(())
            };

            let result = sample();

            let _ = results.send(result);
            drop(results);
        });

        Ok(Self {
            trace: chain_trace,
            stop_marker: stop_marker_tx,
            progress,
        })
    }

    fn flush(&self) -> Result<()> {
        self.trace
            .lock()
            .map_err(|_| anyhow::anyhow!("Could not lock trace mutex"))
            .context("Could not flush trace")?
            .as_mut()
            .map(|v| v.flush())
            .transpose()?;
        Ok(())
    }
}

#[derive(Debug)]
enum SamplerCommand {
    Pause,
    Continue,
    Progress,
    Flush,
    Inspect,
}

enum SamplerResponse<T: Send + 'static> {
    Ok(),
    Progress(Box<[ChainProgress]>),
    Inspect(T),
}

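/// The outcome of [`Sampler::wait_timeout`]: the finalized trace, the still
/// running sampler, or an error (possibly together with a partial trace).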
pub enum SamplerWaitResult<F: Send + 'static> {
    Trace(F),
    Timeout(Sampler<F>),
    Err(anyhow::Error, Option<F>),
}

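/// A handle to a sampler running on a background thread pool.
///
/// A usage sketch following the `sample_parallel` test below (`CpuModel` and
/// `ZarrConfig` are the helpers used there):
///
/// ```ignore
/// let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
/// sampler.pause()?;
/// sampler.resume()?;
/// match sampler.wait_timeout(Duration::from_secs(1)) {
///     SamplerWaitResult::Trace(trace) => { /* all chains finished */ }
///     SamplerWaitResult::Timeout(sampler) => { /* still sampling */ }
///     SamplerWaitResult::Err(err, _partial_trace) => { /* handle the error */ }
/// }
/// ```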
pub struct Sampler<F: Send + 'static> {
    main_thread: JoinHandle<Result<(Option<anyhow::Error>, F)>>,
    commands: SyncSender<SamplerCommand>,
    responses: Receiver<SamplerResponse<(Option<anyhow::Error>, F)>>,
    results: Receiver<Result<()>>,
}

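/// A callback invoked roughly every `rate` with the progress of all chains
/// and the elapsed sampling time (time spent paused is excluded).
///
/// A minimal sketch:
///
/// ```ignore
/// let callback = ProgressCallback {
///     callback: Box::new(|elapsed, progress| {
///         let finished: usize = progress.iter().map(|c| c.finished_draws).sum();
///         println!("{elapsed:?}: {finished} draws");
///     }),
///     rate: Duration::from_millis(500),
/// };
/// ```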
pub struct ProgressCallback {
    pub callback: Box<dyn FnMut(Duration, Box<[ChainProgress]>) + Send>,
    pub rate: Duration,
}

impl<F: Send + 'static> Sampler<F> {
    pub fn new<M, S, C, T>(
        model: M,
        settings: S,
        trace_config: C,
        num_cores: usize,
        callback: Option<ProgressCallback>,
    ) -> Result<Self>
    where
        S: Settings,
        C: StorageConfig<Storage = T>,
        M: Model,
        T: TraceStorage<Finalized = F>,
    {
        let (commands_tx, commands_rx) = sync_channel(0);
        let (responses_tx, responses_rx) = sync_channel(0);
        let (results_tx, results_rx) = channel();

        let main_thread = spawn(move || {
            let pool = ThreadPoolBuilder::new()
                .num_threads(num_cores + 1)
                .thread_name(|i| format!("nutpie-worker-{i}"))
                .build()
                .context("Could not start thread pool")?;

            let settings_ref = &settings;
            let model_ref = &model;
            let mut callback = callback;

            pool.scope_fifo(move |scope| {
                let results = results_tx;
                let mut chains = Vec::with_capacity(settings.num_chains());

                let mut rng = ChaCha8Rng::seed_from_u64(settings.seed());
                rng.set_stream(0);

                let math = model_ref
                    .math(&mut rng)
                    .context("Could not create model density")?;
                let trace = trace_config
                    .new_trace(settings_ref, &math)
                    .context("Could not create trace object")?;
                drop(math);

                for chain_id in 0..settings.num_chains() {
                    let chain_trace_val = trace
                        .initialize_trace_for_chain(chain_id as u64)
                        .context("Failed to create trace object")?;
                    let chain = ChainProcess::start(
                        model_ref,
                        chain_trace_val,
                        chain_id as u64,
                        settings.seed(),
                        settings_ref,
                        scope,
                        results.clone(),
                    );
                    chains.push(chain);
                }
                drop(results);

                let (chains, errors): (Vec<_>, Vec<_>) = chains.into_iter().partition_result();
                if let Some(error) = errors.into_iter().next() {
                    let _ = ChainProcess::finalize_many(trace, chains);
                    return Err(error).context("Could not start chains");
                }

                let mut main_loop = || {
                    let start_time = Instant::now();
                    let mut pause_start = Instant::now();
                    let mut pause_time = Duration::ZERO;

                    let mut progress_rate = Duration::MAX;
                    if let Some(ProgressCallback { callback, rate }) = &mut callback {
                        let progress = chains.iter().map(|chain| chain.progress()).collect_vec();
                        callback(start_time.elapsed(), progress.into());
                        progress_rate = *rate;
                    }
                    let mut last_progress = Instant::now();
                    let mut is_paused = false;

                    loop {
                        let timeout = progress_rate.checked_sub(last_progress.elapsed());
                        let timeout = timeout.unwrap_or_else(|| {
                            if let Some(ProgressCallback { callback, .. }) = &mut callback {
                                let progress =
                                    chains.iter().map(|chain| chain.progress()).collect_vec();
                                let mut elapsed = start_time.elapsed().saturating_sub(pause_time);
                                if is_paused {
                                    elapsed = elapsed.saturating_sub(pause_start.elapsed());
                                }
                                callback(elapsed, progress.into());
                            }
                            last_progress = Instant::now();
                            progress_rate
                        });

                        match commands_rx.recv_timeout(timeout) {
                            Ok(SamplerCommand::Pause) => {
                                for chain in chains.iter() {
                                    let _ = chain.pause();
                                }
                                if !is_paused {
                                    pause_start = Instant::now();
                                }
                                is_paused = true;
                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
                                    anyhow::anyhow!(
                                        "Could not send pause response to controller thread: {e}"
                                    )
                                })?;
                            }
                            Ok(SamplerCommand::Continue) => {
                                for chain in chains.iter() {
                                    let _ = chain.resume();
                                }
                                pause_time += pause_start.elapsed();
                                is_paused = false;
                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
                                    anyhow::anyhow!(
                                        "Could not send continue response to controller thread: {e}"
                                    )
                                })?;
                            }
                            Ok(SamplerCommand::Progress) => {
                                let progress =
                                    chains.iter().map(|chain| chain.progress()).collect_vec();
                                responses_tx
                                    .send(SamplerResponse::Progress(progress.into()))
                                    .map_err(|e| {
                                        anyhow::anyhow!(
                                            "Could not send progress response to controller thread: {e}"
                                        )
                                    })?;
                            }
                            Ok(SamplerCommand::Inspect) => {
                                let traces = chains
                                    .iter()
                                    .filter_map(|chain| {
                                        chain
                                            .trace
                                            .lock()
                                            .expect("Poisoned lock")
                                            .as_ref()
                                            .map(|v| v.inspect())
                                    })
                                    .collect_vec();
                                let finalized_trace = trace.inspect(traces)?;
                                responses_tx
                                    .send(SamplerResponse::Inspect(finalized_trace))
                                    .map_err(|e| {
                                        anyhow::anyhow!(
                                            "Could not send inspect response to controller thread: {e}"
                                        )
                                    })?;
                            }
                            Ok(SamplerCommand::Flush) => {
                                for chain in chains.iter() {
                                    chain.flush()?;
                                }
                                responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
                                    anyhow::anyhow!(
                                        "Could not send flush response to controller thread: {e}"
                                    )
                                })?;
                            }
                            Err(RecvTimeoutError::Timeout) => {}
                            Err(RecvTimeoutError::Disconnected) => {
                                if let Some(ProgressCallback { callback, .. }) = &mut callback {
                                    let progress =
                                        chains.iter().map(|chain| chain.progress()).collect_vec();
                                    let mut elapsed =
                                        start_time.elapsed().saturating_sub(pause_time);
                                    if is_paused {
                                        elapsed = elapsed.saturating_sub(pause_start.elapsed());
                                    }
                                    callback(elapsed, progress.into());
                                }
                                return Ok(());
                            }
                        };
                    }
                };
                let result: Result<()> = main_loop();
                let output = ChainProcess::finalize_many(trace, chains)?;

                result?;
                Ok(output)
            })
        });

        Ok(Self {
            main_thread,
            commands: commands_tx,
            responses: responses_rx,
            results: results_rx,
        })
    }

    pub fn pause(&mut self) -> Result<()> {
        self.commands
            .send(SamplerCommand::Pause)
            .context("Could not send pause command to controller thread")?;
        let response = self
            .responses
            .recv()
            .context("Could not receive pause response from controller thread")?;
        let SamplerResponse::Ok() = response else {
            bail!("Got invalid response from sample controller thread");
        };
        Ok(())
    }

    pub fn resume(&mut self) -> Result<()> {
        self.commands.send(SamplerCommand::Continue)?;
        let response = self.responses.recv()?;
        let SamplerResponse::Ok() = response else {
            bail!("Got invalid response from sample controller thread");
        };
        Ok(())
    }

    pub fn flush(&mut self) -> Result<()> {
        self.commands.send(SamplerCommand::Flush)?;
        let response = self
            .responses
            .recv()
            .context("Could not receive flush response from controller thread")?;
        let SamplerResponse::Ok() = response else {
            bail!("Got invalid response from sample controller thread");
        };
        Ok(())
    }

    pub fn inspect(&mut self) -> Result<(Option<anyhow::Error>, F)> {
        self.commands.send(SamplerCommand::Inspect)?;
        let response = self
            .responses
            .recv()
            .context("Could not receive inspect response from controller thread")?;
        let SamplerResponse::Inspect(trace) = response else {
            bail!("Got invalid response from sample controller thread");
        };
        Ok(trace)
    }

    pub fn abort(self) -> Result<(Option<anyhow::Error>, F)> {
        drop(self.commands);
        let result = self.main_thread.join();
        match result {
            Err(payload) => std::panic::resume_unwind(payload),
            Ok(Ok(val)) => Ok(val),
            Ok(Err(err)) => Err(err),
        }
    }

    pub fn wait_timeout(self, timeout: Duration) -> SamplerWaitResult<F> {
        let start = Instant::now();
        let mut remaining = Some(timeout);
        // Wait for one result per chain, shrinking the deadline as time passes.
        while let Some(time_left) = remaining {
            match self.results.recv_timeout(time_left) {
                Ok(Ok(_)) => remaining = timeout.checked_sub(start.elapsed()),
                Ok(Err(e)) => return SamplerWaitResult::Err(e, None),
                Err(RecvTimeoutError::Disconnected) => match self.abort() {
                    Ok((Some(err), trace)) => return SamplerWaitResult::Err(err, Some(trace)),
                    Ok((None, trace)) => return SamplerWaitResult::Trace(trace),
                    Err(err) => return SamplerWaitResult::Err(err, None),
                },
                Err(RecvTimeoutError::Timeout) => break,
            }
        }
        SamplerWaitResult::Timeout(self)
    }

    pub fn progress(&mut self) -> Result<Box<[ChainProgress]>> {
        self.commands.send(SamplerCommand::Progress)?;
        let response = self.responses.recv()?;
        let SamplerResponse::Progress(progress) = response else {
            bail!("Got invalid response from sample controller thread");
        };
        Ok(progress)
    }
}

#[cfg(test)]
pub mod test_logps {

    use std::collections::HashMap;

    use crate::{
        Model,
        cpu_math::{CpuLogpFunc, CpuMath},
        math_base::LogpError,
    };
    use anyhow::Result;
    use nuts_storable::HasDims;
    use rand::Rng;
    use thiserror::Error;

    #[derive(Clone, Debug)]
    pub struct NormalLogp {
        pub dim: usize,
        pub mu: f64,
    }

    #[derive(Error, Debug)]
    pub enum NormalLogpError {}

    impl LogpError for NormalLogpError {
        fn is_recoverable(&self) -> bool {
            false
        }
    }

    impl HasDims for &NormalLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            vec![
                ("unconstrained_parameter".to_string(), self.dim as u64),
                ("dim".to_string(), self.dim as u64),
            ]
            .into_iter()
            .collect()
        }
    }

    impl CpuLogpFunc for &NormalLogp {
        type LogpError = NormalLogpError;
        type FlowParameters = ();
        type ExpandedVector = Vec<f64>;

        fn dim(&self) -> usize {
            self.dim
        }

        fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, NormalLogpError> {
            let n = position.len();
            assert!(gradient.len() == n);

            let mut logp = 0f64;
            for (p, g) in position.iter().zip(gradient.iter_mut()) {
                let val = self.mu - p;
                logp -= val * val / 2.;
                *g = val;
            }

            Ok(logp)
        }

        fn expand_vector<R>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> std::result::Result<Self::ExpandedVector, crate::cpu_math::CpuMathError>
        where
            R: rand::Rng + ?Sized,
        {
            Ok(array.to_vec())
        }

        fn inv_transform_normalize(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &[f64],
            _untransformed_gradient: &[f64],
            _transformed_position: &mut [f64],
            _transformed_gradient: &mut [f64],
        ) -> std::result::Result<f64, Self::LogpError> {
            unimplemented!()
        }

        fn init_from_untransformed_position(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &[f64],
            _untransformed_gradient: &mut [f64],
            _transformed_position: &mut [f64],
            _transformed_gradient: &mut [f64],
        ) -> std::result::Result<(f64, f64), Self::LogpError> {
            unimplemented!()
        }

        fn init_from_transformed_position(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &mut [f64],
            _untransformed_gradient: &mut [f64],
            _transformed_position: &[f64],
            _transformed_gradient: &mut [f64],
        ) -> std::result::Result<(f64, f64), Self::LogpError> {
            unimplemented!()
        }

        fn update_transformation<'b, R: rand::Rng + ?Sized>(
            &'b mut self,
            _rng: &mut R,
            _untransformed_positions: impl Iterator<Item = &'b [f64]>,
            _untransformed_gradients: impl Iterator<Item = &'b [f64]>,
            _untransformed_logp: impl Iterator<Item = &'b f64>,
            _params: &'b mut Self::FlowParameters,
        ) -> std::result::Result<(), Self::LogpError> {
            unimplemented!()
        }

        fn new_transformation<R: rand::Rng + ?Sized>(
            &mut self,
            _rng: &mut R,
            _untransformed_position: &[f64],
            _untransformed_gradient: &[f64],
            _chain: u64,
        ) -> std::result::Result<Self::FlowParameters, Self::LogpError> {
            unimplemented!()
        }

        fn transformation_id(
            &self,
            _params: &Self::FlowParameters,
        ) -> std::result::Result<i64, Self::LogpError> {
            unimplemented!()
        }
    }

    pub struct CpuModel<F> {
        logp: F,
    }

    impl<F> CpuModel<F> {
        pub fn new(logp: F) -> Self {
            Self { logp }
        }
    }

    impl<F> Model for CpuModel<F>
    where
        F: Send + Sync + 'static,
        for<'a> &'a F: CpuLogpFunc,
    {
        type Math<'model> = CpuMath<&'model F>;

        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            Ok(CpuMath::new(&self.logp))
        }

        fn init_position<R: rand::prelude::Rng + ?Sized>(
            &self,
            _rng: &mut R,
            position: &mut [f64],
        ) -> Result<()> {
            position.iter_mut().for_each(|x| *x = 0.);
            Ok(())
        }
    }
}

#[cfg(test)]
mod tests {
    use std::{
        sync::Arc,
        time::{Duration, Instant},
    };

    use super::test_logps::NormalLogp;
    use crate::{
        Chain, DiagGradNutsSettings, Sampler, ZarrConfig,
        cpu_math::CpuMath,
        sample_sequentially,
        sampler::{Settings, test_logps::CpuModel},
    };

    use anyhow::Result;
    use itertools::Itertools;
    use pretty_assertions::assert_eq;
    use rand::{SeedableRng, rngs::StdRng};
    use zarrs::storage::store::MemoryStore;

    #[test]
    fn sample_chain() -> Result<()> {
        let logp = NormalLogp { dim: 10, mu: 0.1 };
        let math = CpuMath::new(&logp);
        let settings = DiagGradNutsSettings {
            num_tune: 100,
            num_draws: 100,
            ..Default::default()
        };
        let start = vec![0.2; 10];

        let mut rng = StdRng::seed_from_u64(42);

        let mut chain = settings.new_chain(0, math, &mut rng);

        let (_draw, info) = chain.draw()?;
        assert!(info.tuning);
        assert_eq!(info.draw, 0);

        let math = CpuMath::new(&logp);
        let chain = sample_sequentially(math, settings, &start, 200, 1, &mut rng).unwrap();
        let mut draws = chain.collect_vec();
        assert_eq!(draws.len(), 200);

        let draw0 = draws.remove(100).unwrap();
        let (vals, stats) = draw0;
        assert_eq!(vals.len(), 10);
        assert_eq!(stats.chain, 1);
        assert_eq!(stats.draw, 100);
        Ok(())
    }

    #[test]
    fn sample_parallel() -> Result<()> {
        let logp = NormalLogp { dim: 100, mu: 0.1 };
        let settings = DiagGradNutsSettings {
            num_tune: 100,
            num_draws: 100,
            seed: 10,
            ..Default::default()
        };

        let model = CpuModel::new(logp.clone());
        let store = MemoryStore::new();

        let zarr_config = ZarrConfig::new(Arc::new(store));
        let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
        sampler.pause()?;
        sampler.pause()?;
        sampler.resume()?;
        let (error, _) = sampler.abort()?;
        if let Some(err) = error {
            Err(err)?;
        }

        let store = MemoryStore::new();
        let zarr_config = ZarrConfig::new(Arc::new(store));
        let model = CpuModel::new(logp.clone());
        let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
        sampler.pause()?;
        if let (Some(err), _) = sampler.abort()? {
            Err(err)?;
        }

        let store = MemoryStore::new();
        let zarr_config = ZarrConfig::new(Arc::new(store));
        let model = CpuModel::new(logp.clone());
        let start = Instant::now();
        let sampler = Sampler::new(model, settings, zarr_config, 4, None)?;

        let mut sampler = match sampler.wait_timeout(Duration::from_nanos(100)) {
            super::SamplerWaitResult::Trace(_) => {
                dbg!(start.elapsed());
                panic!("finished");
            }
            super::SamplerWaitResult::Timeout(sampler) => sampler,
            super::SamplerWaitResult::Err(_, _) => {
                panic!("error")
            }
        };

        for _ in 0..30 {
            sampler.progress()?;
        }

        match sampler.wait_timeout(Duration::from_secs(1)) {
            super::SamplerWaitResult::Trace(_) => {
                dbg!(start.elapsed());
            }
            super::SamplerWaitResult::Timeout(_) => {
                panic!("timeout")
            }
            super::SamplerWaitResult::Err(err, _) => Err(err)?,
        };

        Ok(())
    }

    #[test]
    fn sample_seq() {
        let logp = NormalLogp { dim: 10, mu: 0.1 };
        let math = CpuMath::new(&logp);
        let settings = DiagGradNutsSettings {
            num_tune: 100,
            num_draws: 100,
            ..Default::default()
        };
        let start = vec![0.2; 10];

        let mut rng = StdRng::seed_from_u64(42);

        let chain = sample_sequentially(math, settings, &start, 200, 1, &mut rng).unwrap();
        let mut draws = chain.collect_vec();
        assert_eq!(draws.len(), 200);

        let draw0 = draws.remove(100).unwrap();
        let (vals, stats) = draw0;
        assert_eq!(vals.len(), 10);
        assert_eq!(stats.chain, 1);
        assert_eq!(stats.draw, 100);
    }
}