1use anyhow::{Context, Result, bail};
2use itertools::Itertools;
3use nuts_storable::{HasDims, Storable, Value};
4use rand::{Rng, SeedableRng, rngs::ChaCha8Rng, rngs::SmallRng};
5use rayon::{ScopeFifo, ThreadPoolBuilder};
6use serde::Serialize;
7use std::{
8 collections::HashMap,
9 fmt::Debug,
10 ops::Deref,
11 sync::{
12 Arc, Mutex,
13 mpsc::{
14 Receiver, RecvTimeoutError, Sender, SyncSender, TryRecvError, channel, sync_channel,
15 },
16 },
17 thread::{JoinHandle, spawn},
18 time::{Duration, Instant},
19};
20
21use crate::{
22 DiagAdaptExpSettings,
23 adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy, GlobalStrategyStatsOptions},
24 chain::{AdaptStrategy, Chain, NutsChain, StatOptions},
25 euclidean_hamiltonian::EuclideanHamiltonian,
26 mass_matrix::DiagMassMatrix,
27 mass_matrix::Strategy as DiagMassMatrixStrategy,
28 mass_matrix::{LowRankMassMatrix, LowRankMassMatrixStrategy, LowRankSettings},
29 math_base::Math,
30 model::Model,
31 nuts::NutsOptions,
32 sampler_stats::{SamplerStats, StatsDims},
33 storage::{ChainStorage, StorageConfig, TraceStorage},
34 transform_adapt_strategy::{TransformAdaptation, TransformedSettings},
35 transformed_hamiltonian::{TransformedHamiltonian, TransformedPointStatsOptions},
36};
37
/// Configuration for a NUTS sampler run.
///
/// An implementor describes how to construct a chain (adaptation strategy,
/// hamiltonian, mass matrix) and exposes metadata about the sampler
/// statistics and expanded draw values that will be stored in the trace.
/// The trait is sealed: only the settings types defined in this crate can
/// implement it.
pub trait Settings:
    private::Sealed + Clone + Copy + Default + Sync + Send + Serialize + 'static
{
    /// The concrete chain type produced by these settings.
    type Chain<M: Math>: Chain<M>;

    /// Create a new chain with the given chain index, math backend and RNG.
    fn new_chain<M: Math, R: Rng + ?Sized>(
        &self,
        chain: u64,
        math: M,
        rng: &mut R,
    ) -> Self::Chain<M>;

    /// Expected number of tuning (warmup) draws.
    fn hint_num_tune(&self) -> usize;
    /// Expected number of posterior draws.
    fn hint_num_draws(&self) -> usize;
    /// Number of chains to run.
    fn num_chains(&self) -> usize;
    /// Base RNG seed for the whole sampler run.
    fn seed(&self) -> u64;
    /// Options controlling which sampler statistics are recorded.
    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions;

    /// Names of all per-draw sampler statistics.
    fn stat_names<M: Math>(&self, math: &M) -> Vec<String> {
        let dims = StatsDims::from(math);
        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::names(&dims)
            .into_iter()
            .map(String::from)
            .collect()
    }

    /// Names of all values in an expanded draw.
    fn data_names<M: Math>(&self, math: &M) -> Vec<String> {
        <M::ExpandedVector as Storable<_>>::names(math)
            .into_iter()
            .map(String::from)
            .collect()
    }

    /// `(name, item type)` pairs for every sampler statistic.
    fn stat_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
        self.stat_names(math)
            .into_iter()
            .map(|name| (name.clone(), self.stat_type::<M>(math, &name)))
            .collect()
    }

    /// Item type of a single named sampler statistic.
    fn stat_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
        let dims = StatsDims::from(math);
        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::item_type(&dims, name)
    }

    /// `(name, item type)` pairs for every expanded draw value.
    fn data_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
        self.data_names(math)
            .into_iter()
            .map(|name| (name.clone(), self.data_type(math, &name)))
            .collect()
    }
    /// Item type of a single named expanded draw value.
    fn data_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
        <M::ExpandedVector as Storable<_>>::item_type(math, name)
    }

    /// `(name, dimension names)` pairs for every sampler statistic.
    fn stat_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
        self.stat_names(math)
            .into_iter()
            .map(|name| (name.clone(), self.stat_dims::<M>(math, &name)))
            .collect()
    }

    /// Dimension names of a single named sampler statistic.
    fn stat_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
        let dims = StatsDims::from(math);
        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::dims(&dims, name)
            .into_iter()
            .map(String::from)
            .collect()
    }

    /// Sizes of all dimensions used by the sampler statistics.
    fn stat_dim_sizes<M: Math>(&self, math: &M) -> HashMap<String, u64> {
        let dims = StatsDims::from(math);
        dims.dim_sizes()
    }

    /// `(name, dimension names)` pairs for every expanded draw value.
    fn data_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
        self.data_names(math)
            .into_iter()
            .map(|name| (name.clone(), self.data_dims(math, &name)))
            .collect()
    }

    /// Dimension names of a single named expanded draw value.
    fn data_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
        <M::ExpandedVector as Storable<_>>::dims(math, name)
            .into_iter()
            .map(String::from)
            .collect()
    }

    /// Coordinate values for the statistic dimensions.
    fn stat_coords<M: Math>(&self, math: &M) -> HashMap<String, Value> {
        let dims = StatsDims::from(math);
        dims.coords()
    }
}
133
/// Snapshot of per-draw information reported by a chain.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Progress {
    /// Index of the draw within its chain (tuning draws included).
    pub draw: u64,
    /// Index of the chain that produced this draw.
    pub chain: u64,
    /// Whether the trajectory of this draw diverged.
    pub diverging: bool,
    /// Whether the chain is still in the tuning (warmup) phase.
    pub tuning: bool,
    /// Current integrator step size.
    pub step_size: f64,
    /// Number of leapfrog steps used for this draw.
    pub num_steps: u64,
}
144
/// Sealing module: prevents crates outside this one from implementing
/// [`Settings`], so new required methods can be added without a breaking
/// change.
mod private {
    use crate::DiagGradNutsSettings;

    use super::{LowRankNutsSettings, TransformedNutsSettings};

    /// Marker trait; only the settings types defined in this crate implement it.
    pub trait Sealed {}

    impl Sealed for DiagGradNutsSettings {}

    impl Sealed for LowRankNutsSettings {}

    impl Sealed for TransformedNutsSettings {}
}
158
/// Settings for the NUTS sampler, generic over the adaptation options `A`.
#[derive(Debug, Clone, Copy, Serialize)]
pub struct NutsSettings<A: Debug + Copy + Default + Serialize> {
    /// Number of tuning (warmup) draws.
    pub num_tune: u64,
    /// Number of draws after tuning.
    pub num_draws: u64,
    /// Maximum tree depth of a NUTS trajectory.
    pub maxdepth: u64,
    /// Minimum tree depth of a NUTS trajectory.
    pub mindepth: u64,
    /// Store the gradient of each draw in the trace.
    pub store_gradient: bool,
    /// Store the unconstrained position of each draw in the trace.
    pub store_unconstrained: bool,
    /// Energy error above which a leapfrog step is treated as divergent.
    pub max_energy_error: f64,
    /// Store divergence information in the trace.
    pub store_divergences: bool,
    /// Options for the adaptation strategy (mass matrix, step size, ...).
    pub adapt_options: A,
    /// Check for U-turns while building trajectories.
    pub check_turning: bool,

    /// Number of chains to sample.
    pub num_chains: usize,
    /// Base seed from which all chain RNGs are derived.
    pub seed: u64,
}
188
/// NUTS settings using diagonal mass matrix adaptation.
pub type DiagGradNutsSettings = NutsSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;
/// NUTS settings using low-rank mass matrix adaptation.
pub type LowRankNutsSettings = NutsSettings<EuclideanAdaptOptions<LowRankSettings>>;
/// NUTS settings using transformation-based adaptation.
pub type TransformedNutsSettings = NutsSettings<TransformedSettings>;
192
193impl Default for DiagGradNutsSettings {
194 fn default() -> Self {
195 Self {
196 num_tune: 400,
197 num_draws: 1000,
198 maxdepth: 10,
199 mindepth: 0,
200 max_energy_error: 1000f64,
201 store_gradient: false,
202 store_unconstrained: false,
203 store_divergences: false,
204 adapt_options: EuclideanAdaptOptions::default(),
205 check_turning: true,
206 seed: 0,
207 num_chains: 6,
208 }
209 }
210}
211
212impl Default for LowRankNutsSettings {
213 fn default() -> Self {
214 let mut vals = Self {
215 num_tune: 800,
216 num_draws: 1000,
217 maxdepth: 10,
218 mindepth: 0,
219 max_energy_error: 1000f64,
220 store_gradient: false,
221 store_unconstrained: false,
222 store_divergences: false,
223 adapt_options: EuclideanAdaptOptions::default(),
224 check_turning: true,
225 seed: 0,
226 num_chains: 6,
227 };
228 vals.adapt_options.mass_matrix_update_freq = 10;
229 vals
230 }
231}
232
233impl Default for TransformedNutsSettings {
234 fn default() -> Self {
235 Self {
236 num_tune: 1500,
237 num_draws: 1000,
238 maxdepth: 10,
239 mindepth: 0,
240 max_energy_error: 20f64,
241 store_gradient: false,
242 store_unconstrained: false,
243 store_divergences: false,
244 adapt_options: Default::default(),
245 check_turning: true,
246 seed: 0,
247 num_chains: 1,
248 }
249 }
250}
251
/// Chain type driven by diagonal mass matrix adaptation.
type DiagGradNutsChain<M> = NutsChain<M, SmallRng, GlobalStrategy<M, DiagMassMatrixStrategy<M>>>;

/// Chain type driven by low-rank mass matrix adaptation.
type LowRankNutsChain<M> = NutsChain<M, SmallRng, GlobalStrategy<M, LowRankMassMatrixStrategy>>;

/// Chain type driven by transformation-based adaptation.
type TransformingNutsChain<M> = NutsChain<M, SmallRng, TransformAdaptation>;
257
impl Settings for LowRankNutsSettings {
    type Chain<M: Math> = LowRankNutsChain<M>;

    /// Build a chain with a low-rank mass matrix and euclidean hamiltonian.
    fn new_chain<M: Math, R: Rng + ?Sized>(
        &self,
        chain: u64,
        mut math: M,
        mut rng: &mut R,
    ) -> Self::Chain<M> {
        let num_tune = self.num_tune;
        // Each constructor takes `&mut math`; the calls run in sequence.
        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
        let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
        let max_energy_error = self.max_energy_error;
        // Initial step size of 1.0; adaptation adjusts it during tuning.
        let potential = EuclideanHamiltonian::new(&mut math, mass_matrix, max_energy_error, 1f64);

        let options = NutsOptions {
            maxdepth: self.maxdepth,
            mindepth: self.mindepth,
            store_gradient: self.store_gradient,
            store_divergences: self.store_divergences,
            store_unconstrained: self.store_unconstrained,
            check_turning: self.check_turning,
        };

        // Derive a fast per-chain RNG from the caller's seeded RNG.
        let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");

        NutsChain::new(
            math,
            potential,
            strategy,
            options,
            rng,
            chain,
            self.stats_options(),
        )
    }

    fn hint_num_tune(&self) -> usize {
        self.num_tune as _
    }

    fn hint_num_draws(&self) -> usize {
        self.num_draws as _
    }

    fn num_chains(&self) -> usize {
        self.num_chains
    }

    fn seed(&self) -> u64 {
        self.seed
    }

    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
        StatOptions {
            adapt: GlobalStrategyStatsOptions {
                mass_matrix: (),
                step_size: (),
            },
            hamiltonian: (),
            point: (),
        }
    }
}
322
impl Settings for DiagGradNutsSettings {
    type Chain<M: Math> = DiagGradNutsChain<M>;

    /// Build a chain with a diagonal mass matrix and euclidean hamiltonian.
    fn new_chain<M: Math, R: Rng + ?Sized>(
        &self,
        chain: u64,
        mut math: M,
        mut rng: &mut R,
    ) -> Self::Chain<M> {
        let num_tune = self.num_tune;
        // Each constructor takes `&mut math`; the calls run in sequence.
        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
        let mass_matrix = DiagMassMatrix::new(
            &mut math,
            self.adapt_options.mass_matrix_options.store_mass_matrix,
        );
        let max_energy_error = self.max_energy_error;
        // Initial step size of 1.0; adaptation adjusts it during tuning.
        let potential = EuclideanHamiltonian::new(&mut math, mass_matrix, max_energy_error, 1f64);

        let options = NutsOptions {
            maxdepth: self.maxdepth,
            mindepth: self.mindepth,
            store_gradient: self.store_gradient,
            store_divergences: self.store_divergences,
            store_unconstrained: self.store_unconstrained,
            check_turning: self.check_turning,
        };

        // Derive a fast per-chain RNG from the caller's seeded RNG.
        let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");

        NutsChain::new(
            math,
            potential,
            strategy,
            options,
            rng,
            chain,
            self.stats_options(),
        )
    }

    fn hint_num_tune(&self) -> usize {
        self.num_tune as _
    }

    fn hint_num_draws(&self) -> usize {
        self.num_draws as _
    }

    fn num_chains(&self) -> usize {
        self.num_chains
    }

    fn seed(&self) -> u64 {
        self.seed
    }

    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
        StatOptions {
            adapt: GlobalStrategyStatsOptions {
                mass_matrix: (),
                step_size: (),
            },
            hamiltonian: (),
            point: (),
        }
    }
}
390
391impl Settings for TransformedNutsSettings {
392 type Chain<M: Math> = TransformingNutsChain<M>;
393
394 fn new_chain<M: Math, R: Rng + ?Sized>(
395 &self,
396 chain: u64,
397 mut math: M,
398 mut rng: &mut R,
399 ) -> Self::Chain<M> {
400 let num_tune = self.num_tune;
401 let max_energy_error = self.max_energy_error;
402
403 let strategy = TransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
404 let hamiltonian = TransformedHamiltonian::new(&mut math, max_energy_error);
405
406 let options = NutsOptions {
407 maxdepth: self.maxdepth,
408 mindepth: self.mindepth,
409 store_gradient: self.store_gradient,
410 store_divergences: self.store_divergences,
411 store_unconstrained: self.store_unconstrained,
412 check_turning: self.check_turning,
413 };
414
415 let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");
416 NutsChain::new(
417 math,
418 hamiltonian,
419 strategy,
420 options,
421 rng,
422 chain,
423 self.stats_options(),
424 )
425 }
426
427 fn hint_num_tune(&self) -> usize {
428 self.num_tune as _
429 }
430
431 fn hint_num_draws(&self) -> usize {
432 self.num_draws as _
433 }
434
435 fn num_chains(&self) -> usize {
436 self.num_chains
437 }
438
439 fn seed(&self) -> u64 {
440 self.seed
441 }
442
443 fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
444 let point = TransformedPointStatsOptions {
446 store_transformed: self.store_unconstrained,
447 };
448 StatOptions {
449 adapt: (),
450 hamiltonian: (),
451 point,
452 }
453 }
454}
455
456pub fn sample_sequentially<'math, M: Math + 'math, R: Rng + ?Sized>(
457 math: M,
458 settings: DiagGradNutsSettings,
459 start: &[f64],
460 draws: u64,
461 chain: u64,
462 rng: &mut R,
463) -> Result<impl Iterator<Item = Result<(Box<[f64]>, Progress)>> + 'math> {
464 let mut sampler = settings.new_chain(chain, math, rng);
465 sampler.set_position(start)?;
466 Ok((0..draws).map(move |_| sampler.draw()))
467}
468
/// Running progress statistics for a single chain.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct ChainProgress {
    /// Number of draws completed so far (tuning included).
    pub finished_draws: usize,
    /// Total number of draws this chain will produce.
    pub total_draws: usize,
    /// Number of divergent draws after tuning.
    pub divergences: usize,
    /// Whether the chain is still tuning.
    pub tuning: bool,
    /// Whether the chain has started sampling.
    pub started: bool,
    /// Leapfrog steps of the most recent draw.
    pub latest_num_steps: usize,
    /// Total leapfrog steps across all draws.
    pub total_num_steps: usize,
    /// Current integrator step size.
    pub step_size: f64,
    /// Accumulated time spent producing draws.
    pub runtime: Duration,
    /// Draw indices at which divergences occurred (post-tuning only).
    pub divergent_draws: Vec<usize>,
}
483
484impl ChainProgress {
485 fn new(total: usize) -> Self {
486 Self {
487 finished_draws: 0,
488 total_draws: total,
489 divergences: 0,
490 tuning: true,
491 started: false,
492 latest_num_steps: 0,
493 step_size: 0f64,
494 total_num_steps: 0,
495 runtime: Duration::ZERO,
496 divergent_draws: Vec::new(),
497 }
498 }
499
500 fn update(&mut self, stats: &Progress, draw_duration: Duration) {
501 if stats.diverging & !stats.tuning {
502 self.divergences += 1;
503 self.divergent_draws.push(self.finished_draws);
504 }
505 self.finished_draws += 1;
506 self.tuning = stats.tuning;
507
508 self.latest_num_steps = stats.num_steps as usize;
509 self.total_num_steps += stats.num_steps as usize;
510 self.step_size = stats.step_size;
511 self.runtime += draw_duration;
512 }
513}
514
/// Control messages sent from the controller to a running chain worker.
enum ChainCommand {
    /// Continue (or keep) sampling.
    Resume,
    /// Block until the next command arrives.
    Pause,
}
519
/// Handle to a single chain running on the rayon pool.
struct ChainProcess<T>
where
    T: TraceStorage,
{
    /// Channel used to pause/resume the worker.
    stop_marker: Sender<ChainCommand>,
    /// Chain trace, shared with the worker; `None` once finalized.
    trace: Arc<Mutex<Option<T::ChainStorage>>>,
    /// Progress state, shared with the worker.
    progress: Arc<Mutex<ChainProgress>>,
}
528
impl<T: TraceStorage> ChainProcess<T> {
    /// Take every remaining chain trace, finalize each one, and fold them
    /// into the overall trace. Chains whose trace was already taken are
    /// skipped.
    fn finalize_many(trace: T, chains: Vec<Self>) -> Result<(Option<anyhow::Error>, T::Finalized)> {
        let finalized_chain_traces = chains
            .into_iter()
            .filter_map(|chain| chain.trace.lock().expect("Poisoned lock").take())
            .map(|chain| chain.finalize())
            .collect_vec();
        trace.finalize(finalized_chain_traces)
    }

    /// Clone the current progress snapshot of this chain.
    fn progress(&self) -> ChainProgress {
        self.progress.lock().expect("Poisoned lock").clone()
    }

    /// Ask the worker to resume sampling.
    fn resume(&self) -> Result<()> {
        self.stop_marker.send(ChainCommand::Resume)?;
        Ok(())
    }

    /// Ask the worker to pause at the next draw boundary.
    fn pause(&self) -> Result<()> {
        self.stop_marker.send(ChainCommand::Pause)?;
        Ok(())
    }

    /// Spawn the sampling loop for one chain on the rayon scope.
    ///
    /// The final `Result` of the loop is delivered through `results`;
    /// the returned handle shares the trace and progress state with the
    /// spawned worker.
    fn start<'model, M: Model, S: Settings>(
        model: &'model M,
        chain_trace: T::ChainStorage,
        chain_id: u64,
        seed: u64,
        settings: &'model S,
        scope: &ScopeFifo<'model>,
        results: Sender<Result<()>>,
    ) -> Result<Self> {
        let (stop_marker_tx, stop_marker_rx) = channel();

        // Each chain gets its own RNG stream; stream 0 is reserved for
        // the controller, so chains use 1..=num_chains.
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        rng.set_stream(chain_id + 1);

        let chain_trace = Arc::new(Mutex::new(Some(chain_trace)));
        let progress = Arc::new(Mutex::new(ChainProgress::new(
            settings.hint_num_draws() + settings.hint_num_tune(),
        )));

        let trace_inner = chain_trace.clone();
        let progress_inner = progress.clone();

        scope.spawn_fifo(move |_| {
            let chain_trace = trace_inner;
            let progress = progress_inner;

            // The whole sampling run as a fallible closure so that any
            // `?` failure is reported through the results channel below.
            let mut sample = move || {
                let logp = model
                    .math(&mut rng)
                    .context("Failed to create model density")?;
                let dim = logp.dim();

                let mut sampler = settings.new_chain(chain_id, logp, &mut rng);

                progress.lock().expect("Poisoned mutex").started = true;

                // Try up to 500 initial positions until one has a
                // usable (finite) density/gradient.
                let mut initval = vec![0f64; dim];
                let mut error = None;
                for _ in 0..500 {
                    model
                        .init_position(&mut rng, &mut initval)
                        .context("Failed to generate a new initial position")?;
                    if let Err(err) = sampler.set_position(&initval) {
                        error = Some(err);
                        continue;
                    }
                    error = None;
                    break;
                }

                if let Some(error) = error {
                    return Err(error.context("All initialization points failed"));
                }

                let draws = settings.hint_num_tune() + settings.hint_num_draws();

                let mut msg = stop_marker_rx.try_recv();
                let mut draw = 0;
                loop {
                    match msg {
                        // Controller dropped its sender: stop sampling.
                        Err(TryRecvError::Disconnected) => {
                            break;
                        }
                        Err(TryRecvError::Empty) => {}
                        // Pause: block until the next command (or until
                        // the controller disconnects, which also breaks).
                        Ok(ChainCommand::Pause) => {
                            msg = stop_marker_rx.recv().map_err(|e| e.into());
                            continue;
                        }
                        Ok(ChainCommand::Resume) => {}
                    }

                    let now = Instant::now();
                    // NOTE(review): unwrap on a draw failure aborts the
                    // worker thread instead of reporting through
                    // `results` — confirm this is intended.
                    let (_point, mut draw_data, mut stats, info) = sampler.expanded_draw().unwrap();

                    let mut guard = chain_trace
                        .lock()
                        .expect("Could not unlock trace lock. Poisoned mutex");

                    // Trace already taken (finalized elsewhere): stop.
                    let Some(trace_val) = guard.as_mut() else {
                        break;
                    };
                    progress
                        .lock()
                        .expect("Poisoned mutex")
                        .update(&info, now.elapsed());

                    let math = sampler.math();
                    let dims = StatsDims::from(math.deref());
                    trace_val.record_sample(
                        settings,
                        stats.get_all(&dims),
                        draw_data.get_all(math.deref()),
                        &info,
                    )?;

                    draw += 1;
                    if draw == draws {
                        break;
                    }

                    msg = stop_marker_rx.try_recv();
                }
                Ok(())
            };

            let result = sample();

            // The controller may already have hung up; ignore send errors.
            let _ = results.send(result);
            drop(results);
        });

        Ok(Self {
            trace: chain_trace,
            stop_marker: stop_marker_tx,
            progress,
        })
    }

    /// Flush any buffered trace data for this chain to storage.
    fn flush(&self) -> Result<()> {
        self.trace
            .lock()
            .map_err(|_| anyhow::anyhow!("Could not lock trace mutex"))
            .context("Could not flush trace")?
            .as_mut()
            .map(|v| v.flush())
            .transpose()?;
        Ok(())
    }
}
688
/// Commands a `Sampler` handle sends to its controller thread.
#[derive(Debug)]
enum SamplerCommand {
    /// Pause all chains.
    Pause,
    /// Resume all chains.
    Continue,
    /// Request a progress snapshot of all chains.
    Progress,
    /// Flush all chain traces to storage.
    Flush,
    /// Request a finalized snapshot of the trace so far.
    Inspect,
}
697
/// Responses the controller thread sends back for each [`SamplerCommand`].
enum SamplerResponse<T: Send + 'static> {
    /// Command acknowledged (pause/continue/flush).
    Ok(),
    /// Progress snapshot for every chain.
    Progress(Box<[ChainProgress]>),
    /// Snapshot of the trace produced by an inspect request.
    Inspect(T),
}
703
/// Outcome of [`Sampler::wait_timeout`].
pub enum SamplerWaitResult<F: Send + 'static> {
    /// Sampling finished; contains the finalized trace.
    Trace(F),
    /// The timeout elapsed; the sampler handle is returned for reuse.
    Timeout(Sampler<F>),
    /// Sampling failed; the partial trace is included when available.
    Err(anyhow::Error, Option<F>),
}
709
/// Handle to a sampler run executing on a background controller thread.
pub struct Sampler<F: Send + 'static> {
    /// Controller thread; joining yields the finalized trace.
    main_thread: JoinHandle<Result<(Option<anyhow::Error>, F)>>,
    /// Command channel to the controller (rendezvous).
    commands: SyncSender<SamplerCommand>,
    /// Response channel from the controller (rendezvous).
    responses: Receiver<SamplerResponse<(Option<anyhow::Error>, F)>>,
    /// Per-chain completion results, one message per finished chain.
    results: Receiver<Result<()>>,
}
716
/// User-supplied progress reporting hook.
pub struct ProgressCallback {
    /// Called with the elapsed (non-paused) time and per-chain progress.
    pub callback: Box<dyn FnMut(Duration, Box<[ChainProgress]>) + Send>,
    /// Minimum interval between callback invocations.
    pub rate: Duration,
}
721
722impl<F: Send + 'static> Sampler<F> {
723 pub fn new<M, S, C, T>(
724 model: M,
725 settings: S,
726 trace_config: C,
727 num_cores: usize,
728 callback: Option<ProgressCallback>,
729 ) -> Result<Self>
730 where
731 S: Settings,
732 C: StorageConfig<Storage = T>,
733 M: Model,
734 T: TraceStorage<Finalized = F>,
735 {
736 let (commands_tx, commands_rx) = sync_channel(0);
737 let (responses_tx, responses_rx) = sync_channel(0);
738 let (results_tx, results_rx) = channel();
739
740 let main_thread = spawn(move || {
741 let pool = ThreadPoolBuilder::new()
742 .num_threads(num_cores + 1) .thread_name(|i| format!("nutpie-worker-{i}"))
744 .build()
745 .context("Could not start thread pool")?;
746
747 let settings_ref = &settings;
748 let model_ref = &model;
749 let mut callback = callback;
750
751 pool.scope_fifo(move |scope| {
752 let results = results_tx;
753 let mut chains = Vec::with_capacity(settings.num_chains());
754
755 let mut rng = ChaCha8Rng::seed_from_u64(settings.seed());
756 rng.set_stream(0);
757
758 let math = model_ref
759 .math(&mut rng)
760 .context("Could not create model density")?;
761 let trace = trace_config
762 .new_trace(settings_ref, &math)
763 .context("Could not create trace object")?;
764 drop(math);
765
766 for chain_id in 0..settings.num_chains() {
767 let chain_trace_val = trace
768 .initialize_trace_for_chain(chain_id as u64)
769 .context("Failed to create trace object")?;
770 let chain = ChainProcess::start(
771 model_ref,
772 chain_trace_val,
773 chain_id as u64,
774 settings.seed(),
775 settings_ref,
776 scope,
777 results.clone(),
778 );
779 chains.push(chain);
780 }
781 drop(results);
782
783 let (chains, errors): (Vec<_>, Vec<_>) = chains.into_iter().partition_result();
784 if let Some(error) = errors.into_iter().next() {
785 let _ = ChainProcess::finalize_many(trace, chains);
786 return Err(error).context("Could not start chains");
787 }
788
789 let mut main_loop = || {
790 let start_time = Instant::now();
791 let mut pause_start = Instant::now();
792 let mut pause_time = Duration::ZERO;
793
794 let mut progress_rate = Duration::MAX;
795 if let Some(ProgressCallback { callback, rate }) = &mut callback {
796 let progress = chains.iter().map(|chain| chain.progress()).collect_vec();
797 callback(start_time.elapsed(), progress.into());
798 progress_rate = *rate;
799 }
800 let mut last_progress = Instant::now();
801 let mut is_paused = false;
802
803 loop {
804 let timeout = progress_rate.checked_sub(last_progress.elapsed());
805 let timeout = timeout.unwrap_or_else(|| {
806 if let Some(ProgressCallback { callback, .. }) = &mut callback {
807 let progress =
808 chains.iter().map(|chain| chain.progress()).collect_vec();
809 let mut elapsed = start_time.elapsed().saturating_sub(pause_time);
810 if is_paused {
811 elapsed = elapsed.saturating_sub(pause_start.elapsed());
812 }
813 callback(elapsed, progress.into());
814 }
815 last_progress = Instant::now();
816 progress_rate
817 });
818
819 match commands_rx.recv_timeout(timeout) {
821 Ok(SamplerCommand::Pause) => {
822 for chain in chains.iter() {
823 let _ = chain.pause();
826 }
827 if !is_paused {
828 pause_start = Instant::now();
829 }
830 is_paused = true;
831 responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
832 anyhow::anyhow!(
833 "Could not send pause response to controller thread: {e}"
834 )
835 })?;
836 }
837 Ok(SamplerCommand::Continue) => {
838 for chain in chains.iter() {
839 let _ = chain.resume();
842 }
843 pause_time += pause_start.elapsed();
844 is_paused = false;
845 responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
846 anyhow::anyhow!(
847 "Could not send continue response to controller thread: {e}"
848 )
849 })?;
850 }
851 Ok(SamplerCommand::Progress) => {
852 let progress =
853 chains.iter().map(|chain| chain.progress()).collect_vec();
854 responses_tx.send(SamplerResponse::Progress(progress.into())).map_err(|e| {
855 anyhow::anyhow!(
856 "Could not send progress response to controller thread: {e}"
857 )
858 })?;
859 }
860 Ok(SamplerCommand::Inspect) => {
861 let traces = chains
862 .iter()
863 .map(|chain| {
864 chain
865 .trace
866 .lock()
867 .expect("Poisoned lock")
868 .as_ref()
869 .map(|v| v.inspect())
870 })
871 .flatten()
872 .collect_vec();
873 let finalized_trace = trace.inspect(traces)?;
874 responses_tx.send(SamplerResponse::Inspect(finalized_trace)).map_err(|e| {
875 anyhow::anyhow!(
876 "Could not send inspect response to controller thread: {e}"
877 )
878 })?;
879 }
880 Ok(SamplerCommand::Flush) => {
881 for chain in chains.iter() {
882 chain.flush()?;
883 }
884 responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
885 anyhow::anyhow!(
886 "Could not send flush response to controller thread: {e}"
887 )
888 })?;
889 }
890 Err(RecvTimeoutError::Timeout) => {}
891 Err(RecvTimeoutError::Disconnected) => {
892 if let Some(ProgressCallback { callback, .. }) = &mut callback {
893 let progress =
894 chains.iter().map(|chain| chain.progress()).collect_vec();
895 let mut elapsed =
896 start_time.elapsed().saturating_sub(pause_time);
897 if is_paused {
898 elapsed = elapsed.saturating_sub(pause_start.elapsed());
899 }
900 callback(elapsed, progress.into());
901 }
902 return Ok(());
903 }
904 };
905 }
906 };
907 let result: Result<()> = main_loop();
908 let output = ChainProcess::finalize_many(trace, chains)?;
910
911 result?;
912 Ok(output)
913 })
914 });
915
916 Ok(Self {
917 main_thread,
918 commands: commands_tx,
919 responses: responses_rx,
920 results: results_rx,
921 })
922 }
923
924 pub fn pause(&mut self) -> Result<()> {
925 self.commands
926 .send(SamplerCommand::Pause)
927 .context("Could not send pause command to controller thread")?;
928 let response = self
929 .responses
930 .recv()
931 .context("Could not recieve pause response from controller thread")?;
932 let SamplerResponse::Ok() = response else {
933 bail!("Got invalid response from sample controller thread");
934 };
935 Ok(())
936 }
937
938 pub fn resume(&mut self) -> Result<()> {
939 self.commands.send(SamplerCommand::Continue)?;
940 let response = self.responses.recv()?;
941 let SamplerResponse::Ok() = response else {
942 bail!("Got invalid response from sample controller thread");
943 };
944 Ok(())
945 }
946
947 pub fn flush(&mut self) -> Result<()> {
948 self.commands.send(SamplerCommand::Flush)?;
949 let response = self
950 .responses
951 .recv()
952 .context("Could not recieve flush response from controller thread")?;
953 let SamplerResponse::Ok() = response else {
954 bail!("Got invalid response from sample controller thread");
955 };
956 Ok(())
957 }
958
959 pub fn inspect(&mut self) -> Result<(Option<anyhow::Error>, F)> {
960 self.commands.send(SamplerCommand::Inspect)?;
961 let response = self
962 .responses
963 .recv()
964 .context("Could not recieve inspect response from controller thread")?;
965 let SamplerResponse::Inspect(trace) = response else {
966 bail!("Got invalid response from sample controller thread");
967 };
968 Ok(trace)
969 }
970
971 pub fn abort(self) -> Result<(Option<anyhow::Error>, F)> {
972 drop(self.commands);
973 let result = self.main_thread.join();
974 match result {
975 Err(payload) => std::panic::resume_unwind(payload),
976 Ok(Ok(val)) => Ok(val),
977 Ok(Err(err)) => Err(err),
978 }
979 }
980
981 pub fn wait_timeout(self, timeout: Duration) -> SamplerWaitResult<F> {
982 let start = Instant::now();
983 let mut remaining = Some(timeout);
984 while remaining.is_some() {
985 match self.results.recv_timeout(timeout) {
986 Ok(Ok(_)) => remaining = timeout.checked_sub(start.elapsed()),
987 Ok(Err(e)) => return SamplerWaitResult::Err(e, None),
988 Err(RecvTimeoutError::Disconnected) => match self.abort() {
989 Ok((Some(err), trace)) => return SamplerWaitResult::Err(err, Some(trace)),
990 Ok((None, trace)) => return SamplerWaitResult::Trace(trace),
991 Err(err) => return SamplerWaitResult::Err(err, None),
992 },
993 Err(RecvTimeoutError::Timeout) => break,
994 }
995 }
996 SamplerWaitResult::Timeout(self)
997 }
998
999 pub fn progress(&mut self) -> Result<Box<[ChainProgress]>> {
1000 self.commands.send(SamplerCommand::Progress)?;
1001 let response = self.responses.recv()?;
1002 let SamplerResponse::Progress(progress) = response else {
1003 bail!("Got invalid response from sample controller thread");
1004 };
1005 Ok(progress)
1006 }
1007}
1008
#[cfg(test)]
pub mod test_logps {
    //! Simple log-density implementations used by the unit tests.

    use std::collections::HashMap;

    use crate::{
        Model,
        cpu_math::{CpuLogpFunc, CpuMath},
        math_base::LogpError,
    };
    use anyhow::Result;
    use nuts_storable::HasDims;
    use rand::Rng;
    use thiserror::Error;

    /// Independent normal density with common mean `mu` in `dim` dimensions.
    #[derive(Clone, Debug)]
    pub struct NormalLogp {
        pub dim: usize,
        pub mu: f64,
    }

    /// Error type for [`NormalLogp`]; it can never actually occur.
    #[derive(Error, Debug)]
    pub enum NormalLogpError {}

    impl LogpError for NormalLogpError {
        fn is_recoverable(&self) -> bool {
            false
        }
    }

    impl HasDims for &NormalLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            vec![
                ("unconstrained_parameter".to_string(), self.dim as u64),
                ("dim".to_string(), self.dim as u64),
            ]
            .into_iter()
            .collect()
        }
    }

    impl CpuLogpFunc for &NormalLogp {
        type LogpError = NormalLogpError;
        type FlowParameters = ();
        type ExpandedVector = Vec<f64>;

        fn dim(&self) -> usize {
            self.dim
        }

        /// Log density of independent `N(mu, 1)`: `-sum((mu - p)^2) / 2`,
        /// with the gradient `mu - p` written into `gradient`.
        fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, NormalLogpError> {
            let n = position.len();
            assert!(gradient.len() == n);

            let mut logp = 0f64;
            for (p, g) in position.iter().zip(gradient.iter_mut()) {
                let val = self.mu - p;
                logp -= val * val / 2.;
                *g = val;
            }

            Ok(logp)
        }

        /// Identity expansion: the expanded draw is the position itself.
        fn expand_vector<R>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> std::result::Result<Self::ExpandedVector, crate::cpu_math::CpuMathError>
        where
            R: rand::Rng + ?Sized,
        {
            Ok(array.to_vec())
        }

        // The transformation-based API is not exercised by these tests;
        // all of the following methods are intentionally unimplemented.
        // (Parameter-name typos from the original — `_untransofrmed_…`,
        // `_untransfogmed_…` — are fixed here; the names are unused.)

        fn inv_transform_normalize(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &[f64],
            _untransformed_gradient: &[f64],
            _transformed_position: &mut [f64],
            _transformed_gradient: &mut [f64],
        ) -> std::result::Result<f64, Self::LogpError> {
            unimplemented!()
        }

        fn init_from_untransformed_position(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &[f64],
            _untransformed_gradient: &mut [f64],
            _transformed_position: &mut [f64],
            _transformed_gradient: &mut [f64],
        ) -> std::result::Result<(f64, f64), Self::LogpError> {
            unimplemented!()
        }

        fn init_from_transformed_position(
            &mut self,
            _params: &Self::FlowParameters,
            _untransformed_position: &mut [f64],
            _untransformed_gradient: &mut [f64],
            _transformed_position: &[f64],
            _transformed_gradient: &mut [f64],
        ) -> std::result::Result<(f64, f64), Self::LogpError> {
            unimplemented!()
        }

        fn update_transformation<'b, R: rand::Rng + ?Sized>(
            &'b mut self,
            _rng: &mut R,
            _untransformed_positions: impl Iterator<Item = &'b [f64]>,
            _untransformed_gradients: impl Iterator<Item = &'b [f64]>,
            _untransformed_logp: impl Iterator<Item = &'b f64>,
            _params: &'b mut Self::FlowParameters,
        ) -> std::result::Result<(), Self::LogpError> {
            unimplemented!()
        }

        fn new_transformation<R: rand::Rng + ?Sized>(
            &mut self,
            _rng: &mut R,
            _untransformed_position: &[f64],
            _untransformed_gradient: &[f64],
            _chain: u64,
        ) -> std::result::Result<Self::FlowParameters, Self::LogpError> {
            unimplemented!()
        }

        fn transformation_id(
            &self,
            _params: &Self::FlowParameters,
        ) -> std::result::Result<i64, Self::LogpError> {
            unimplemented!()
        }
    }

    /// Minimal [`Model`] wrapper around a CPU log-density function.
    pub struct CpuModel<F> {
        logp: F,
    }

    impl<F> CpuModel<F> {
        pub fn new(logp: F) -> Self {
            Self { logp }
        }
    }

    impl<F> Model for CpuModel<F>
    where
        F: Send + Sync + 'static,
        for<'a> &'a F: CpuLogpFunc,
    {
        type Math<'model> = CpuMath<&'model F>;

        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            Ok(CpuMath::new(&self.logp))
        }

        /// Deterministic initialization at the origin.
        fn init_position<R: rand::prelude::Rng + ?Sized>(
            &self,
            _rng: &mut R,
            position: &mut [f64],
        ) -> Result<()> {
            position.iter_mut().for_each(|x| *x = 0.);
            Ok(())
        }
    }
}
1177
#[cfg(test)]
mod tests {
    use std::{
        sync::Arc,
        time::{Duration, Instant},
    };

    use super::test_logps::NormalLogp;
    use crate::{
        Chain, DiagGradNutsSettings, Sampler, ZarrConfig,
        cpu_math::CpuMath,
        sample_sequentially,
        sampler::{Settings, test_logps::CpuModel},
    };

    use anyhow::Result;
    use itertools::Itertools;
    use pretty_assertions::assert_eq;
    use rand::{SeedableRng, rngs::StdRng};
    use zarrs::storage::store::MemoryStore;

    /// Draws from a single chain, directly and via `sample_sequentially`.
    #[test]
    fn sample_chain() -> Result<()> {
        let logp = NormalLogp { dim: 10, mu: 0.1 };
        let math = CpuMath::new(&logp);
        let settings = DiagGradNutsSettings {
            num_tune: 100,
            num_draws: 100,
            ..Default::default()
        };
        let start = vec![0.2; 10];

        let mut rng = StdRng::seed_from_u64(42);

        let mut chain = settings.new_chain(0, math, &mut rng);

        // First draw happens during tuning and has index 0.
        let (_draw, info) = chain.draw()?;
        assert!(info.tuning);
        assert_eq!(info.draw, 0);

        let math = CpuMath::new(&logp);
        let chain = sample_sequentially(math, settings, &start, 200, 1, &mut rng).unwrap();
        let mut draws = chain.collect_vec();
        assert_eq!(draws.len(), 200);

        // Draw 100 is the first post-tuning draw.
        let draw0 = draws.remove(100).unwrap();
        let (vals, stats) = draw0;
        assert_eq!(vals.len(), 10);
        assert_eq!(stats.chain, 1);
        assert_eq!(stats.draw, 100);
        Ok(())
    }

    /// Exercises the parallel sampler: pause/resume, abort, and
    /// wait_timeout with both a too-short and a sufficient timeout.
    #[test]
    fn sample_parallel() -> Result<()> {
        let logp = NormalLogp { dim: 100, mu: 0.1 };
        let settings = DiagGradNutsSettings {
            num_tune: 100,
            num_draws: 100,
            seed: 10,
            ..Default::default()
        };

        let model = CpuModel::new(logp.clone());
        let store = MemoryStore::new();

        // Pausing twice then resuming must not deadlock or error.
        let zarr_config = ZarrConfig::new(Arc::new(store));
        let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
        sampler.pause()?;
        sampler.pause()?;
        sampler.resume()?;
        let (ok, _) = sampler.abort()?;
        if let Some(err) = ok {
            Err(err)?;
        }

        // Abort while paused must still finalize cleanly.
        let store = MemoryStore::new();
        let zarr_config = ZarrConfig::new(Arc::new(store));
        let model = CpuModel::new(logp.clone());
        let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
        sampler.pause()?;
        if let (Some(err), _) = sampler.abort()? {
            Err(err)?;
        }

        let store = MemoryStore::new();
        let zarr_config = ZarrConfig::new(Arc::new(store));
        let model = CpuModel::new(logp.clone());
        let start = Instant::now();
        let sampler = Sampler::new(model, settings, zarr_config, 4, None)?;

        // A 100ns timeout cannot be enough; we expect Timeout.
        let mut sampler = match sampler.wait_timeout(Duration::from_nanos(100)) {
            super::SamplerWaitResult::Trace(_) => {
                dbg!(start.elapsed());
                panic!("finished");
            }
            super::SamplerWaitResult::Timeout(sampler) => sampler,
            super::SamplerWaitResult::Err(_, _) => {
                panic!("error")
            }
        };

        // Progress queries must work while sampling is running.
        for _ in 0..30 {
            sampler.progress()?;
        }

        // One second should be plenty for this tiny model.
        match sampler.wait_timeout(Duration::from_secs(1)) {
            super::SamplerWaitResult::Trace(_) => {
                dbg!(start.elapsed());
            }
            super::SamplerWaitResult::Timeout(_) => {
                panic!("timeout")
            }
            super::SamplerWaitResult::Err(err, _) => Err(err)?,
        };

        Ok(())
    }

    /// Sequential sampling smoke test (same expectations as sample_chain).
    #[test]
    fn sample_seq() {
        let logp = NormalLogp { dim: 10, mu: 0.1 };
        let math = CpuMath::new(&logp);
        let settings = DiagGradNutsSettings {
            num_tune: 100,
            num_draws: 100,
            ..Default::default()
        };
        let start = vec![0.2; 10];

        let mut rng = StdRng::seed_from_u64(42);

        let chain = sample_sequentially(math, settings, &start, 200, 1, &mut rng).unwrap();
        let mut draws = chain.collect_vec();
        assert_eq!(draws.len(), 200);

        let draw0 = draws.remove(100).unwrap();
        let (vals, stats) = draw0;
        assert_eq!(vals.len(), 10);
        assert_eq!(stats.chain, 1);
        assert_eq!(stats.draw, 100);
    }
}
1321}