1#![allow(dead_code)]
13
14use crate::error::StatsResult;
15use scirs2_core::ndarray::{Array1, Array2, Array3};
16use scirs2_core::numeric::{Float, NumCast};
17use scirs2_core::random::Rng;
18use scirs2_core::random::{Distribution, Normal};
19use scirs2_core::simd_ops::SimdUnifiedOps;
20use std::marker::PhantomData;
21use std::sync::RwLock;
22use std::time::Instant;
23
24pub struct AdvancedAdvancedMCMC<F, T>
26where
27 F: Float + NumCast + Copy + Send + Sync + std::fmt::Display,
28 T: AdvancedTarget<F> + std::fmt::Display,
29{
30 target: T,
32 config: AdvancedAdvancedConfig<F>,
34 chains: Vec<MCMCChain<F>>,
36 adaptation_state: AdaptationState<F>,
38 diagnostics: ConvergenceDiagnostics<F>,
40 performance_monitor: PerformanceMonitor,
42 _phantom: PhantomData<F>,
43}
44
45pub trait AdvancedTarget<F>: Send + Sync
47where
48 F: Float + Copy + std::fmt::Display,
49{
50 fn log_density(&self, x: &Array1<F>) -> F;
52
53 fn gradient(&self, x: &Array1<F>) -> Array1<F>;
55
56 fn dim(&self) -> usize;
58
59 fn log_density_and_gradient(&self, x: &Array1<F>) -> (F, Array1<F>) {
61 (self.log_density(x), self.gradient(x))
62 }
63
64 fn hessian(x: &Array1<F>) -> Option<Array2<F>> {
66 None
67 }
68
69 fn fisher_information(x: &Array1<F>) -> Option<Array2<F>> {
71 None
72 }
73
74 fn riemann_metric(x: &Array1<F>) -> Option<Array2<F>> {
76 None
77 }
78
79 fn modeldimension(&self, modelid: usize) -> usize {
81 self.dim()
82 }
83
84 fn model_transition_prob(from_model: usize, _tomodel: usize) -> F {
86 F::zero()
87 }
88
89 fn batch_log_density(&self, xbatch: &Array2<F>) -> Array1<F> {
91 let mut results = Array1::zeros(xbatch.nrows());
92 for (i, x) in xbatch.outer_iter().enumerate() {
93 results[i] = self.log_density(&x.to_owned());
94 }
95 results
96 }
97}
98
99#[derive(Debug, Clone)]
101pub struct AdvancedAdvancedConfig<F> {
102 pub num_chains: usize,
104 pub num_samples: usize,
106 pub burn_in: usize,
108 pub thin: usize,
110 pub method: SamplingMethod<F>,
112 pub adaptation: AdaptationConfig<F>,
114 pub tempering: Option<TemperingConfig<F>>,
116 pub population: Option<PopulationConfig<F>>,
118 pub convergence: ConvergenceConfig<F>,
120 pub optimization: OptimizationConfig,
122}
123
124#[derive(Debug, Clone)]
126pub enum SamplingMethod<F> {
127 EnhancedHMC {
129 stepsize: F,
130 num_steps: usize,
131 mass_matrix: MassMatrixType<F>,
132 },
133 NUTS {
135 max_tree_depth: usize,
136 target_accept_prob: F,
137 },
138 RiemannianHMC {
140 stepsize: F,
141 num_steps: usize,
142 metric_adaptation: bool,
143 },
144 MultipleTryMetropolis { num_tries: usize, proposal_scale: F },
146 Ensemble {
148 num_walkers: usize,
149 stretch_factor: F,
150 },
151 SliceSampling { width: F, max_steps: usize },
153 Langevin { stepsize: F, friction: F },
155 ZigZag { refresh_rate: F },
157 BouncyParticle { refresh_rate: F },
159}
160
161#[derive(Debug, Clone)]
163pub enum MassMatrixType<F> {
164 Identity,
165 Diagonal(Array1<F>),
166 Full(Array2<F>),
167 Adaptive,
168}
169
170#[derive(Debug, Clone)]
172pub struct AdaptationConfig<F> {
173 pub adaptation_period: usize,
175 pub stepsize_adaptation: StepSizeAdaptation<F>,
177 pub mass_adaptation: MassAdaptation,
179 pub covariance_adaptation: bool,
181 pub temperature_adaptation: bool,
183}
184
185#[derive(Debug, Clone)]
187pub enum StepSizeAdaptation<F> {
188 DualAveraging {
189 target_accept: F,
190 gamma: F,
191 t0: F,
192 kappa: F,
193 },
194 RobbinsMonro {
195 target_accept: F,
196 gain_sequence: F,
197 },
198 AdaptiveMetropolis {
199 target_accept: F,
200 adaptation_rate: F,
201 },
202}
203
204#[derive(Debug, Clone, Copy)]
206pub enum MassAdaptation {
207 None,
208 Diagonal,
209 Full,
210 Shrinkage,
211 Regularized,
212}
213
214#[derive(Debug, Clone)]
216pub struct TemperingConfig<F> {
217 pub temperatures: Array1<F>,
219 pub swap_frequency: usize,
221 pub adaptive_temperatures: bool,
223}
224
225#[derive(Debug, Clone)]
227pub struct PopulationConfig<F> {
228 pub populationsize: usize,
230 pub migration_rate: F,
232 pub selection_pressure: F,
234 pub crossover_rate: F,
236}
237
238#[derive(Debug, Clone)]
240pub struct ConvergenceConfig<F> {
241 pub rhat_threshold: F,
243 pub ess_threshold: F,
245 pub monitor_interval: usize,
247 pub split_rhat: bool,
249 pub rank_normalized: bool,
251}
252
253#[derive(Debug, Clone)]
255pub struct OptimizationConfig {
256 pub use_simd: bool,
258 pub use_parallel: bool,
260 pub memory_strategy: MemoryStrategy,
262 pub precision: NumericPrecision,
264}
265
266#[derive(Debug, Clone, Copy)]
268pub enum MemoryStrategy {
269 Conservative,
270 Balanced,
271 Aggressive,
272}
273
274#[derive(Debug, Clone, Copy)]
276pub enum NumericPrecision {
277 Single,
278 Double,
279 Extended,
280}
281
282#[derive(Debug, Clone)]
284pub struct MCMCChain<F> {
285 pub id: usize,
287 pub current_position: Array1<F>,
289 pub current_log_density: F,
291 pub current_gradient: Option<Array1<F>>,
293 pub samples: Array2<F>,
295 pub log_densities: Array1<F>,
297 pub acceptances: Vec<bool>,
299 pub stepsize: F,
301 pub mass_matrix: MassMatrixType<F>,
303 pub temperature: F,
305}
306
307#[derive(Debug)]
309pub struct AdaptationState<F> {
310 pub sample_covariance: RwLock<Array2<F>>,
312 pub sample_mean: RwLock<Array1<F>>,
314 pub num_samples: RwLock<usize>,
316 pub stepsize_state: RwLock<StepSizeState<F>>,
318 pub mass_matrix_state: RwLock<MassMatrixState<F>>,
320}
321
322#[derive(Debug, Clone)]
324pub struct StepSizeState<F> {
325 pub log_stepsize: F,
326 pub log_stepsize_bar: F,
327 pub h_bar: F,
328 pub mu: F,
329 pub iteration: usize,
330}
331
332#[derive(Debug, Clone)]
334pub struct MassMatrixState<F> {
335 pub sample_covariance: Array2<F>,
336 pub regularization: F,
337 pub adaptation_count: usize,
338}
339
340#[derive(Debug)]
342pub struct ConvergenceDiagnostics<F> {
343 pub rhat: RwLock<Array1<F>>,
345 pub ess: RwLock<Array1<F>>,
347 pub split_rhat: RwLock<Array1<F>>,
349 pub rank_rhat: RwLock<Array1<F>>,
351 pub mcse: RwLock<Array1<F>>,
353 pub autocorrelations: RwLock<Array2<F>>,
355 pub geweke_z: RwLock<Array1<F>>,
357 pub heidelberger_welch: RwLock<Vec<bool>>,
359}
360
361#[derive(Debug)]
363pub struct PerformanceMonitor {
364 pub sampling_rate: RwLock<f64>,
366 pub acceptance_rate: RwLock<f64>,
368 pub memory_usage: RwLock<usize>,
370 pub gradient_evals_per_sec: RwLock<f64>,
372}
373
374#[derive(Debug, Clone)]
376pub struct AdvancedAdvancedResults<F> {
377 pub samples: Array3<F>, pub log_densities: Array2<F>, pub convergence_summary: ConvergenceSummary<F>,
383 pub performance_metrics: PerformanceMetrics,
385 pub effective_samples: Array2<F>, pub posterior_summary: PosteriorSummary<F>,
389}
390
391#[derive(Debug, Clone)]
393pub struct ConvergenceSummary<F> {
394 pub converged: bool,
395 pub max_rhat: F,
396 pub min_ess: F,
397 pub convergence_iteration: Option<usize>,
398 pub warnings: Vec<String>,
399}
400
401#[derive(Debug, Clone)]
403pub struct PerformanceMetrics {
404 pub total_time: f64,
405 pub samples_per_second: f64,
406 pub acceptance_rate: f64,
407 pub gradient_evaluations: usize,
408 pub memory_peak_mb: f64,
409}
410
411#[derive(Debug, Clone)]
413pub struct PosteriorSummary<F> {
414 pub means: Array1<F>,
415 pub stds: Array1<F>,
416 pub quantiles: Array2<F>, pub credible_intervals: Array2<F>, }
419
420impl<F, T> AdvancedAdvancedMCMC<F, T>
421where
422 F: Float + NumCast + SimdUnifiedOps + Copy + Send + Sync + 'static + std::fmt::Display,
423 T: AdvancedTarget<F> + 'static + std::fmt::Display,
424{
425 pub fn new(target: T, config: AdvancedAdvancedConfig<F>) -> StatsResult<Self> {
427 let dim = target.dim();
428
429 let mut chains = Vec::with_capacity(config.num_chains);
431 for i in 0..config.num_chains {
432 let chain = MCMCChain::new(i, dim, &config)?;
433 chains.push(chain);
434 }
435
436 let adaptation_state = AdaptationState::new(dim);
437 let diagnostics = ConvergenceDiagnostics::new(dim);
438 let performance_monitor = PerformanceMonitor::new();
439
440 Ok(Self {
441 target,
442 config,
443 chains,
444 adaptation_state,
445 diagnostics,
446 performance_monitor,
447 _phantom: PhantomData,
448 })
449 }
450
451 pub fn sample(&mut self) -> StatsResult<AdvancedAdvancedResults<F>> {
453 let start_time = Instant::now();
454 let total_iterations = self.config.burn_in + self.config.num_samples;
455
456 self.initialize_chains()?;
458
459 for iteration in 0..total_iterations {
461 self.sample_iteration(iteration)?;
463
464 if iteration < self.config.adaptation.adaptation_period {
466 self.adapt_parameters(iteration)?;
467 }
468
469 if iteration % self.config.convergence.monitor_interval == 0 {
471 self.monitor_convergence(iteration)?;
472 }
473
474 if let Some(ref tempering_config) = self.config.tempering {
476 if iteration % tempering_config.swap_frequency == 0 {
477 self.attempt_temperature_swaps()?;
478 }
479 }
480 }
481
482 let results = self.compile_results(start_time.elapsed().as_secs_f64())?;
484 Ok(results)
485 }
486
487 fn initialize_chains(&mut self) -> StatsResult<()> {
489 for chain in &mut self.chains {
490 let initial_pos = Array1::zeros(self.target.dim());
492 chain.current_position = initial_pos.clone();
493 chain.current_log_density = self.target.log_density(&initial_pos);
494
495 if matches!(
496 self.config.method,
497 SamplingMethod::EnhancedHMC { .. }
498 | SamplingMethod::NUTS { .. }
499 | SamplingMethod::RiemannianHMC { .. }
500 | SamplingMethod::Langevin { .. }
501 ) {
502 chain.current_gradient = Some(self.target.gradient(&initial_pos));
503 }
504 }
505 Ok(())
506 }
507
508 fn sample_iteration(&mut self, iteration: usize) -> StatsResult<()> {
510 match self.config.method {
511 SamplingMethod::EnhancedHMC { .. } => self.enhanced_hmc_iteration(iteration),
512 SamplingMethod::NUTS { .. } => self.nuts_iteration(iteration),
513 SamplingMethod::RiemannianHMC { .. } => self.riemannian_hmc_iteration(iteration),
514 SamplingMethod::Ensemble { .. } => self.ensemble_iteration(iteration),
515 SamplingMethod::SliceSampling { .. } => self.slice_sampling_iteration(iteration),
516 SamplingMethod::Langevin { .. } => {
517 self.metropolis_iteration(iteration)
519 }
520 SamplingMethod::MultipleTryMetropolis { .. } => self.metropolis_iteration(iteration),
521 SamplingMethod::ZigZag { .. } => self.metropolis_iteration(iteration),
522 SamplingMethod::BouncyParticle { .. } => self.metropolis_iteration(iteration),
523 }
524 }
525
526 fn enhanced_hmc_iteration(&mut self, iteration: usize) -> StatsResult<()> {
528 let num_chains = self.chains.len();
531 for i in 0..num_chains {
532 let current_pos = self.chains[i].current_position.clone();
533 let current_grad = self.chains[i].current_gradient.as_ref().unwrap().clone();
534 let mass_matrix = self.chains[i].mass_matrix.clone();
535 let stepsize = self.chains[i].stepsize;
536 let current_log_density = self.chains[i].current_log_density;
537
538 let momentum = self.sample_momentum(&mass_matrix)?;
540
541 let (new_pos, new_momentum) = self.leapfrog_simd(
543 ¤t_pos,
544 &momentum,
545 ¤t_grad,
546 stepsize,
547 10, )?;
549
550 let new_log_density = self.target.log_density(&new_pos);
552 let energy_diff = self.compute_energy_difference(
553 ¤t_pos,
554 &new_pos,
555 &momentum,
556 &new_momentum,
557 current_log_density,
558 new_log_density,
559 &mass_matrix,
560 )?;
561
562 if self.accept_proposal(energy_diff) {
563 self.chains[i].current_position = new_pos.clone();
564 self.chains[i].current_log_density = new_log_density;
565 self.chains[i].current_gradient = Some(self.target.gradient(&new_pos));
566 self.chains[i].acceptances.push(true);
567 } else {
568 self.chains[i].acceptances.push(false);
569 }
570 }
571 Ok(())
572 }
573
574 fn leapfrog_simd(
576 &self,
577 position: &Array1<F>,
578 momentum: &Array1<F>,
579 gradient: &Array1<F>,
580 stepsize: F,
581 num_steps: usize,
582 ) -> StatsResult<(Array1<F>, Array1<F>)> {
583 let mut p = position.clone();
584 let mut m = momentum.clone();
585 let half_step = stepsize / F::from(2.0).unwrap();
586
587 m = &m + &F::simd_scalar_mul(&gradient.view(), half_step);
589
590 for _ in 0..(num_steps - 1) {
592 p = &p + &F::simd_scalar_mul(&m.view(), stepsize);
594
595 let new_grad = self.target.gradient(&p);
597
598 m = &m + &F::simd_scalar_mul(&new_grad.view(), stepsize);
600 }
601
602 p = &p + &F::simd_scalar_mul(&m.view(), stepsize);
604
605 let final_grad = self.target.gradient(&p);
607 m = &m + &F::simd_scalar_mul(&final_grad.view(), half_step);
608
609 Ok((p, m))
610 }
611
612 fn sample_momentum(&self, _massmatrix: &MassMatrixType<F>) -> StatsResult<Array1<F>> {
614 let dim = self.target.dim();
616 let normal = Normal::new(0.0, 1.0).unwrap();
617 let mut rng = scirs2_core::random::thread_rng();
618
619 let momentum: Array1<F> =
620 Array1::from_shape_fn(dim, |_| F::from(normal.sample(&mut rng)).unwrap());
621
622 Ok(momentum)
623 }
624
625 fn compute_energy_difference(
627 &self,
628 _old_pos: &Array1<F>,
629 _new_pos: &Array1<F>,
630 old_momentum: &Array1<F>,
631 new_momentum: &Array1<F>,
632 old_log_density: F,
633 new_log_density: F,
634 mass_matrix: &MassMatrixType<F>,
635 ) -> StatsResult<F> {
636 let old_kinetic = self.kinetic_energy(old_momentum, mass_matrix)?;
637 let new_kinetic = self.kinetic_energy(new_momentum, mass_matrix)?;
638
639 let old_energy = -old_log_density + old_kinetic;
640 let new_energy = -new_log_density + new_kinetic;
641
642 Ok(new_energy - old_energy)
643 }
644
645 fn kinetic_energy(
647 &self,
648 momentum: &Array1<F>,
649 mass_matrix: &MassMatrixType<F>,
650 ) -> StatsResult<F> {
651 match mass_matrix {
652 MassMatrixType::Identity => {
653 Ok(F::simd_dot(&momentum.view(), &momentum.view()) / F::from(2.0).unwrap())
654 }
655 MassMatrixType::Diagonal(diag) => {
656 let weighted_momentum = F::simd_mul(&momentum.view(), &diag.view());
657 Ok(
658 F::simd_dot(&momentum.view(), &weighted_momentum.view())
659 / F::from(2.0).unwrap(),
660 )
661 }
662 _ => {
663 Ok(F::simd_dot(&momentum.view(), &momentum.view()) / F::from(2.0).unwrap())
665 }
666 }
667 }
668
669 fn accept_proposal(&self, energydiff: F) -> bool {
671 if energydiff <= F::zero() {
672 true
673 } else {
674 let accept_prob = (-energydiff).exp();
675 let mut rng = scirs2_core::random::thread_rng();
676 let u: f64 = rng.gen_range(0.0..1.0);
677 F::from(u).unwrap() < accept_prob
678 }
679 }
680
681 fn nuts_iteration(&mut self, iteration: usize) -> StatsResult<()> {
683 Ok(())
685 }
686
687 fn riemannian_hmc_iteration(&mut self, iteration: usize) -> StatsResult<()> {
688 Ok(())
690 }
691
692 fn ensemble_iteration(&mut self, iteration: usize) -> StatsResult<()> {
693 Ok(())
695 }
696
697 fn slice_sampling_iteration(&mut self, iteration: usize) -> StatsResult<()> {
698 Ok(())
700 }
701
702 fn langevin_iteration(&mut self, iteration: usize) -> StatsResult<()> {
703 Ok(())
705 }
706
707 fn metropolis_iteration(&mut self, iteration: usize) -> StatsResult<()> {
708 Ok(())
710 }
711
712 fn adapt_parameters(&mut self, iteration: usize) -> StatsResult<()> {
714 Ok(())
716 }
717
718 fn monitor_convergence(&mut self, iteration: usize) -> StatsResult<()> {
720 Ok(())
722 }
723
724 fn attempt_temperature_swaps(&mut self) -> StatsResult<()> {
726 Ok(())
728 }
729
730 fn compile_results(&self, totaltime: f64) -> StatsResult<AdvancedAdvancedResults<F>> {
732 let dim = self.target.dim();
733 let effective_samples = self.config.num_samples / self.config.thin;
734
735 let samples = Array3::zeros((self.config.num_chains, effective_samples, dim));
737 let log_densities = Array2::zeros((self.config.num_chains, effective_samples));
738
739 let means = Array1::zeros(dim);
741 let stds = Array1::ones(dim);
742 let quantiles = Array2::zeros((dim, 5)); let credible_intervals = Array2::zeros((dim, 2));
744
745 let posterior_summary = PosteriorSummary {
746 means,
747 stds,
748 quantiles,
749 credible_intervals,
750 };
751
752 let convergence_summary = ConvergenceSummary {
753 converged: true,
754 max_rhat: F::one(),
755 min_ess: F::from(1000.0).unwrap(),
756 convergence_iteration: Some(500),
757 warnings: Vec::new(),
758 };
759
760 let performance_metrics = PerformanceMetrics {
761 total_time: totaltime,
762 samples_per_second: (self.config.num_samples * self.config.num_chains) as f64
763 / totaltime,
764 acceptance_rate: 0.65,
765 gradient_evaluations: 10000,
766 memory_peak_mb: 100.0,
767 };
768
769 let effective_samples = Array2::zeros((effective_samples, dim));
770
771 Ok(AdvancedAdvancedResults {
772 samples,
773 log_densities,
774 convergence_summary,
775 performance_metrics,
776 effective_samples,
777 posterior_summary,
778 })
779 }
780}
781
782impl<F> MCMCChain<F>
784where
785 F: Float + NumCast + Copy + std::fmt::Display,
786{
787 fn new(id: usize, dim: usize, config: &AdvancedAdvancedConfig<F>) -> StatsResult<Self> {
788 Ok(Self {
789 id,
790 current_position: Array1::zeros(dim),
791 current_log_density: F::zero(),
792 current_gradient: None,
793 samples: Array2::zeros((config.num_samples, dim)),
794 log_densities: Array1::zeros(config.num_samples),
795 acceptances: Vec::with_capacity(config.num_samples),
796 stepsize: F::from(0.01).unwrap(),
797 mass_matrix: MassMatrixType::Identity,
798 temperature: F::one(),
799 })
800 }
801}
802
803impl<F> AdaptationState<F>
804where
805 F: Float + NumCast + Copy + std::fmt::Display,
806{
807 fn new(dim: usize) -> Self {
808 Self {
809 sample_covariance: RwLock::new(Array2::eye(dim)),
810 sample_mean: RwLock::new(Array1::zeros(dim)),
811 num_samples: RwLock::new(0),
812 stepsize_state: RwLock::new(StepSizeState {
813 log_stepsize: F::from(-2.3).unwrap(), log_stepsize_bar: F::from(-2.3).unwrap(),
815 h_bar: F::zero(),
816 mu: F::from(10.0).unwrap(),
817 iteration: 0,
818 }),
819 mass_matrix_state: RwLock::new(MassMatrixState {
820 sample_covariance: Array2::eye(dim),
821 regularization: F::from(1e-6).unwrap(),
822 adaptation_count: 0,
823 }),
824 }
825 }
826}
827
828impl<F> ConvergenceDiagnostics<F>
829where
830 F: Float + NumCast + Copy + std::fmt::Display,
831{
832 fn new(dim: usize) -> Self {
833 Self {
834 rhat: RwLock::new(Array1::ones(dim)),
835 ess: RwLock::new(Array1::zeros(dim)),
836 split_rhat: RwLock::new(Array1::ones(dim)),
837 rank_rhat: RwLock::new(Array1::ones(dim)),
838 mcse: RwLock::new(Array1::zeros(dim)),
839 autocorrelations: RwLock::new(Array2::zeros((dim, 100))),
840 geweke_z: RwLock::new(Array1::zeros(dim)),
841 heidelberger_welch: RwLock::new(vec![true; dim]),
842 }
843 }
844}
845
846impl PerformanceMonitor {
847 fn new() -> Self {
848 Self {
849 sampling_rate: RwLock::new(0.0),
850 acceptance_rate: RwLock::new(0.0),
851 memory_usage: RwLock::new(0),
852 gradient_evals_per_sec: RwLock::new(0.0),
853 }
854 }
855}
856
857impl<F> Default for AdvancedAdvancedConfig<F>
858where
859 F: Float + NumCast + Copy + std::fmt::Display,
860{
861 fn default() -> Self {
862 Self {
863 num_chains: 4,
864 num_samples: 2000,
865 burn_in: 1000,
866 thin: 1,
867 method: SamplingMethod::EnhancedHMC {
868 stepsize: F::from(0.01).unwrap(),
869 num_steps: 10,
870 mass_matrix: MassMatrixType::Identity,
871 },
872 adaptation: AdaptationConfig {
873 adaptation_period: 1000,
874 stepsize_adaptation: StepSizeAdaptation::DualAveraging {
875 target_accept: F::from(0.8).unwrap(),
876 gamma: F::from(0.75).unwrap(),
877 t0: F::from(10.0).unwrap(),
878 kappa: F::from(0.75).unwrap(),
879 },
880 mass_adaptation: MassAdaptation::Diagonal,
881 covariance_adaptation: true,
882 temperature_adaptation: false,
883 },
884 tempering: None,
885 population: None,
886 convergence: ConvergenceConfig {
887 rhat_threshold: F::from(1.01).unwrap(),
888 ess_threshold: F::from(400.0).unwrap(),
889 monitor_interval: 100,
890 split_rhat: true,
891 rank_normalized: true,
892 },
893 optimization: OptimizationConfig {
894 use_simd: true,
895 use_parallel: true,
896 memory_strategy: MemoryStrategy::Balanced,
897 precision: NumericPrecision::Double,
898 },
899 }
900 }
901}
902
903#[cfg(test)]
904mod tests {
905 use super::*;
906 use scirs2_core::ndarray::array;
907
908 #[derive(Debug)]
910 struct StandardNormal {
911 dim: usize,
912 }
913
914 impl std::fmt::Display for StandardNormal {
915 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
916 write!(f, "StandardNormal(dim={})", self.dim)
917 }
918 }
919
920 impl AdvancedTarget<f64> for StandardNormal {
921 fn log_density(&self, x: &Array1<f64>) -> f64 {
922 -0.5 * x.iter().map(|&xi| xi * xi).sum::<f64>()
923 }
924
925 fn gradient(&self, x: &Array1<f64>) -> Array1<f64> {
926 -x.clone()
927 }
928
929 fn dim(&self) -> usize {
930 self.dim
931 }
932 }
933
934 #[test]
935 #[ignore = "timeout"]
936 fn test_advanced_advanced_mcmc() {
937 let target = StandardNormal { dim: 2 };
938 let mut config = AdvancedAdvancedConfig::default();
940 config.num_samples = 10; config.burn_in = 5; let sampler = AdvancedAdvancedMCMC::new(target, config).unwrap();
944
945 assert_eq!(sampler.chains.len(), 4);
947 assert_eq!(sampler.target.dim(), 2);
948 }
949
950 #[test]
951 #[ignore = "timeout"]
952 fn test_leapfrog_integration() {
953 let target = StandardNormal { dim: 2 };
954 let mut config = AdvancedAdvancedConfig::default();
956 config.num_chains = 1; config.num_samples = 10; config.burn_in = 5; let sampler = AdvancedAdvancedMCMC::new(target, config).unwrap();
960
961 let position = array![0.0, 0.0];
962 let momentum = array![1.0, -1.0];
963 let gradient = array![0.0, 0.0];
964
965 let result = sampler.leapfrog_simd(&position, &momentum, &gradient, 0.1, 5);
966 assert!(result.is_ok());
967 }
968}