use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, NumCast, One, Zero};
use scirs2_core::{simd_ops::SimdUnifiedOps, validation::*};
use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::Arc;

/// Compares candidate Bayesian models under several selection criteria.
#[derive(Debug, Clone)]
pub struct BayesianModelComparison<F> {
    pub models: Vec<BayesianModel<F>>,
    pub criteria: Vec<ModelSelectionCriterion>,
    pub cv_config: CrossValidationConfig,
    pub parallel_config: ParallelConfig,
}

/// A candidate model: identifier, structure, prior, likelihood, and complexity.
#[derive(Debug, Clone)]
pub struct BayesianModel<F> {
    pub id: String,
    pub model_type: ModelType,
    pub prior: AdvancedPrior<F>,
    pub likelihood: LikelihoodType,
    pub complexity: f64,
}

#[derive(Debug, Clone)]
pub enum AdvancedPrior<F> {
    Conjugate {
        parameters: HashMap<String, F>,
    },
    Hierarchical {
        levels: Vec<PriorLevel<F>>,
    },
    Mixture {
        components: Vec<PriorComponent<F>>,
        weights: Array1<F>,
    },
    Sparse {
        sparsity_type: SparsityType,
        sparsity_params: HashMap<String, F>,
    },
    NonParametric {
        process_type: NonParametricProcess,
        concentration: F,
    },
}

#[derive(Debug, Clone)]
pub struct PriorLevel<F> {
    pub level_id: String,
    pub distribution: DistributionType<F>,
    pub dependencies: Vec<String>,
}

#[derive(Debug, Clone)]
pub struct PriorComponent<F> {
    pub weight: F,
    pub distribution: DistributionType<F>,
}

/// Parametric distribution families used for priors and hyperpriors.
pub enum DistributionType<F> {
    Normal {
        mean: F,
        precision: F,
    },
    Gamma {
        shape: F,
        rate: F,
    },
    Beta {
        alpha: F,
        beta: F,
    },
    InverseGamma {
        shape: F,
        scale: F,
    },
    Exponential {
        rate: F,
    },
    Uniform {
        lower: F,
        upper: F,
    },
    StudentT {
        degrees_freedom: F,
        location: F,
        scale: F,
    },
    Laplace {
        location: F,
        scale: F,
    },
    Horseshoe {
        tau: F,
    },
    Custom {
        /// Log-density stored behind `Arc` (rather than `Box`) so that `Clone`
        /// can share the closure instead of panicking.
        log_density: Arc<dyn Fn(F) -> F + Send + Sync>,
        parameters: HashMap<String, F>,
    },
}

impl<F: std::fmt::Debug> std::fmt::Debug for DistributionType<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DistributionType::Normal { mean, precision } => f
                .debug_struct("Normal")
                .field("mean", mean)
                .field("precision", precision)
                .finish(),
            DistributionType::Gamma { shape, rate } => f
                .debug_struct("Gamma")
                .field("shape", shape)
                .field("rate", rate)
                .finish(),
            DistributionType::Beta { alpha, beta } => f
                .debug_struct("Beta")
                .field("alpha", alpha)
                .field("beta", beta)
                .finish(),
            DistributionType::Uniform { lower, upper } => f
                .debug_struct("Uniform")
                .field("lower", lower)
                .field("upper", upper)
                .finish(),
            DistributionType::InverseGamma { shape, scale } => f
                .debug_struct("InverseGamma")
                .field("shape", shape)
                .field("scale", scale)
                .finish(),
            DistributionType::StudentT {
                degrees_freedom,
                location,
                scale,
            } => f
                .debug_struct("StudentT")
                .field("degrees_freedom", degrees_freedom)
                .field("location", location)
                .field("scale", scale)
                .finish(),
            DistributionType::Exponential { rate } => {
                f.debug_struct("Exponential").field("rate", rate).finish()
            }
            DistributionType::Laplace { location, scale } => f
                .debug_struct("Laplace")
                .field("location", location)
                .field("scale", scale)
                .finish(),
            DistributionType::Horseshoe { tau } => {
                f.debug_struct("Horseshoe").field("tau", tau).finish()
            }
            DistributionType::Custom { parameters, .. } => f
                .debug_struct("Custom")
                .field("parameters", parameters)
                .field("log_density", &"<function>")
                .finish(),
        }
    }
}

impl<F: Clone> Clone for DistributionType<F> {
    fn clone(&self) -> Self {
        match self {
            DistributionType::Normal { mean, precision } => DistributionType::Normal {
                mean: mean.clone(),
                precision: precision.clone(),
            },
            DistributionType::Gamma { shape, rate } => DistributionType::Gamma {
                shape: shape.clone(),
                rate: rate.clone(),
            },
            DistributionType::Beta { alpha, beta } => DistributionType::Beta {
                alpha: alpha.clone(),
                beta: beta.clone(),
            },
            DistributionType::Uniform { lower, upper } => DistributionType::Uniform {
                lower: lower.clone(),
                upper: upper.clone(),
            },
            DistributionType::InverseGamma { shape, scale } => DistributionType::InverseGamma {
                shape: shape.clone(),
                scale: scale.clone(),
            },
            DistributionType::StudentT {
                degrees_freedom,
                location,
                scale,
            } => DistributionType::StudentT {
                degrees_freedom: degrees_freedom.clone(),
                location: location.clone(),
                scale: scale.clone(),
            },
            DistributionType::Exponential { rate } => {
                DistributionType::Exponential { rate: rate.clone() }
            }
            DistributionType::Horseshoe { tau } => DistributionType::Horseshoe { tau: tau.clone() },
            DistributionType::Laplace { location, scale } => DistributionType::Laplace {
                location: location.clone(),
                scale: scale.clone(),
            },
            // Sharing the log-density closure through `Arc` lets `Custom` be
            // cloned instead of panicking.
            DistributionType::Custom {
                log_density,
                parameters,
            } => DistributionType::Custom {
                log_density: Arc::clone(log_density),
                parameters: parameters.clone(),
            },
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum SparsityType {
    Horseshoe,
    SpikeAndSlab,
    Lasso,
    ElasticNet,
    FinnishHorseshoe,
}

#[derive(Debug, Clone, Copy)]
pub enum NonParametricProcess {
    DirichletProcess,
    PitmanYor,
    ChineseRestaurant,
    IndianBuffet,
}

#[derive(Debug, Clone)]
pub enum ModelType {
    LinearRegression,
    LogisticRegression,
    GeneralizedLinear { family: GLMFamily },
    HierarchicalLinear { levels: usize },
    GaussianProcess { kernel: KernelType },
    BayesianNeuralNetwork {
        layers: Vec<usize>,
        activation: ActivationType,
    },
    StateSpace {
        state_dim: usize,
        observation_dim: usize,
    },
    Mixture {
        components: usize,
        component_type: ComponentType,
    },
}

#[derive(Debug, Clone, Copy)]
pub enum GLMFamily {
    Gaussian,
    Binomial,
    Poisson,
    Gamma,
    InverseGaussian,
    NegativeBinomial,
}

#[derive(Debug, Clone)]
pub enum KernelType {
    RBF { length_scale: f64 },
    Matern { nu: f64, length_scale: f64 },
    Periodic { period: f64, length_scale: f64 },
    Linear { variance: f64 },
    Polynomial { degree: usize, variance: f64 },
    WhiteNoise { variance: f64 },
    Sum { kernels: Vec<KernelType> },
    Product { kernels: Vec<KernelType> },
}

#[derive(Debug, Clone, Copy)]
pub enum ActivationType {
    ReLU,
    Sigmoid,
    Tanh,
    Swish,
    GELU,
}

#[derive(Debug, Clone, Copy)]
pub enum ComponentType {
    Gaussian,
    StudentT,
    Laplace,
    Skewed,
}

#[derive(Debug, Clone, Copy)]
pub enum LikelihoodType {
    Gaussian,
    Binomial,
    Poisson,
    Gamma,
    Beta,
    Exponential,
    StudentT,
    Laplace,
    Robust,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelSelectionCriterion {
    /// Deviance information criterion.
    DIC,
    /// Widely applicable information criterion.
    WAIC,
    /// Leave-one-out cross-validation.
    LooCv,
    /// (Log) marginal likelihood / model evidence.
    MarginalLikelihood,
    /// Posterior predictive loss.
    PPL,
    /// Cross-validation information criterion.
    CVIC,
}

#[derive(Debug, Clone)]
pub struct CrossValidationConfig {
    pub k_folds: usize,
    pub mc_samples: usize,
    pub seed: Option<u64>,
    pub stratify: bool,
}

#[derive(Debug, Clone)]
pub struct ParallelConfig {
    pub num_chains: usize,
    pub parallel_models: bool,
    pub parallel_cv: bool,
}

#[derive(Debug, Clone)]
pub struct AdvancedBayesianRegression<F> {
    pub model: BayesianModel<F>,
    pub mcmc_config: MCMCConfig,
    pub vi_config: VIConfig,
    _phantom: PhantomData<F>,
}

#[derive(Debug, Clone)]
pub struct MCMCConfig {
    pub n_samples: usize,
    pub n_burnin: usize,
    pub thin: usize,
    pub n_chains: usize,
    pub adaptation_period: usize,
    pub target_acceptance: f64,
    pub use_nuts: bool,
    pub use_hmc: bool,
}

#[derive(Debug, Clone)]
pub struct VIConfig {
    pub max_iter: usize,
    pub tolerance: f64,
    pub learning_rate: f64,
    pub family: VariationalFamily,
    pub n_mc_samples: usize,
}

#[derive(Debug, Clone, Copy)]
pub enum VariationalFamily {
    MeanFieldGaussian,
    FullRankGaussian,
    NormalizingFlow,
    MixtureGaussian,
}

#[derive(Debug, Clone)]
pub struct BayesianGaussianProcess<F> {
    pub x_train: Array2<F>,
    pub y_train: Array1<F>,
    pub kernel: KernelType,
    pub noise_level: F,
    pub hyperpriors: HashMap<String, DistributionType<F>>,
    pub hyperparameter_samples: Option<Array2<F>>,
}

#[derive(Debug, Clone)]
pub struct BayesianNeuralNetwork<F> {
    pub architecture: Vec<usize>,
    pub activations: Vec<ActivationType>,
    pub weight_priors: Vec<DistributionType<F>>,
    pub bias_priors: Vec<DistributionType<F>>,
    pub weight_samples: Option<Vec<Array2<F>>>,
    pub bias_samples: Option<Vec<Array1<F>>>,
}

#[derive(Debug, Clone)]
pub struct ModelComparisonResult<F> {
    pub rankings: HashMap<ModelSelectionCriterion, Vec<String>>,
    pub ic_values: HashMap<String, HashMap<ModelSelectionCriterion, F>>,
    pub bayes_factors: Array2<F>,
    pub model_weights: HashMap<String, F>,
    pub cv_results: HashMap<String, CrossValidationResult<F>>,
    pub best_models: HashMap<ModelSelectionCriterion, String>,
}

#[derive(Debug, Clone)]
pub struct CrossValidationResult<F> {
    pub mean_score: F,
    pub std_error: F,
    pub fold_scores: Array1<F>,
    pub effective_n_params: F,
}

#[derive(Debug, Clone)]
pub struct AdvancedBayesianResult<F> {
    pub posterior_samples: Array2<F>,
    pub posterior_summary: PosteriorSummary<F>,
    pub diagnostics: MCMCDiagnostics<F>,
    pub model_fit: ModelFitMetrics<F>,
    pub predictions: PredictiveDistribution<F>,
}

#[derive(Debug, Clone)]
pub struct PosteriorSummary<F> {
    pub means: Array1<F>,
    pub stds: Array1<F>,
    pub credible_intervals: Array2<F>,
    pub ess: Array1<F>,
    pub rhat: Array1<F>,
}

#[derive(Debug, Clone)]
pub struct MCMCDiagnostics<F> {
    pub acceptance_rates: Array1<F>,
    pub autocorrelations: Array2<F>,
    pub geweke_diagnostic: Array1<F>,
    pub heidelberger_welch: Array1<bool>,
    pub mc_errors: Array1<F>,
}

#[derive(Debug, Clone)]
pub struct ModelFitMetrics<F> {
    pub dic: F,
    pub waic: F,
    pub lppd: F,
    pub p_eff: F,
    pub posterior_p_value: F,
}

#[derive(Debug, Clone)]
pub struct PredictiveDistribution<F> {
    pub means: Array1<F>,
    pub variances: Array1<F>,
    pub quantiles: Array2<F>,
    pub samples: Array2<F>,
}

impl<F> BayesianModelComparison<F>
where
    F: Float
        + NumCast
        + SimdUnifiedOps
        + Zero
        + One
        + PartialOrd
        + Copy
        + Send
        + Sync
        + std::fmt::Display
        + std::iter::Sum<F>,
{
    /// Creates a comparison preconfigured with DIC, WAIC, and LOO-CV.
    pub fn new() -> Self {
        Self {
            models: Vec::new(),
            criteria: vec![
                ModelSelectionCriterion::DIC,
                ModelSelectionCriterion::WAIC,
                ModelSelectionCriterion::LooCv,
            ],
            cv_config: CrossValidationConfig::default(),
            parallel_config: ParallelConfig::default(),
        }
    }

    pub fn add_model(&mut self, model: BayesianModel<F>) {
        self.models.push(model);
    }

    /// Fits every registered model, scores it under each configured criterion,
    /// and assembles rankings, model weights, and cross-validation results.
    pub fn compare_models(
        &self,
        x: &ArrayView2<F>,
        y: &ArrayView1<F>,
    ) -> StatsResult<ModelComparisonResult<F>> {
        checkarray_finite(x, "x")?;
        checkarray_finite(y, "y")?;

        if x.nrows() != y.len() {
            return Err(StatsError::DimensionMismatch(
                "x and y must have the same number of observations".to_string(),
            ));
        }

        let mut rankings = HashMap::new();
        let mut ic_values = HashMap::new();
        let mut cv_results = HashMap::new();

        for model in &self.models {
            let model_result = Self::fit_single_model(model, x, y)?;

            let mut model_ic_values = HashMap::new();
            for criterion in &self.criteria {
                let ic_value = self.compute_criterion(&model_result, criterion)?;
                model_ic_values.insert(*criterion, ic_value);
            }
            ic_values.insert(model.id.clone(), model_ic_values);

            let cv_result = self.cross_validate_model(model, x, y)?;
            cv_results.insert(model.id.clone(), cv_result);
        }

        // Rank models under each criterion; lower scores rank first.
        for criterion in &self.criteria {
            let mut model_scores: Vec<(String, F)> = ic_values
                .iter()
                .map(|(id, scores)| (id.clone(), scores[criterion]))
                .collect();

            model_scores
                .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

            let ranking: Vec<String> = model_scores.into_iter().map(|(id, _)| id).collect();
            rankings.insert(*criterion, ranking);
        }

        // Placeholder: pairwise Bayes factors are set to 1 (no evidence either
        // way) until marginal-likelihood estimation is implemented.
        let n_models = self.models.len();
        let bayes_factors = Array2::ones((n_models, n_models));

        let model_weights = self.compute_model_weights(&ic_values)?;

        let mut best_models = HashMap::new();
        for criterion in &self.criteria {
            if let Some(ranking) = rankings.get(criterion) {
                if let Some(best_model) = ranking.first() {
                    best_models.insert(*criterion, best_model.clone());
                }
            }
        }

        Ok(ModelComparisonResult {
            rankings,
            ic_values,
            bayes_factors,
            model_weights,
            cv_results,
            best_models,
        })
    }

    /// Fits a single model and summarises its posterior.
    ///
    /// NOTE: placeholder implementation that returns fixed summaries and
    /// diagnostics; it does not yet run MCMC or variational inference.
    fn fit_single_model(
        _model: &BayesianModel<F>,
        x: &ArrayView2<F>,
        y: &ArrayView1<F>,
    ) -> StatsResult<AdvancedBayesianResult<F>> {
        let n_params = x.ncols();
        let n_samples = 1000;

        let posterior_samples = Array2::zeros((n_samples, n_params));

        let posterior_summary = PosteriorSummary {
            means: Array1::zeros(n_params),
            stds: Array1::ones(n_params),
            credible_intervals: Array2::zeros((n_params, 2)),
            ess: Array1::from_elem(
                n_params,
                F::from(500.0).expect("Failed to convert constant to float"),
            ),
            rhat: Array1::ones(n_params),
        };

        let diagnostics = MCMCDiagnostics {
            acceptance_rates: Array1::from_elem(
                1,
                F::from(0.6).expect("Failed to convert constant to float"),
            ),
            autocorrelations: Array2::zeros((n_params, 100)),
            geweke_diagnostic: Array1::zeros(n_params),
            heidelberger_welch: Array1::from_elem(n_params, true),
            mc_errors: Array1::zeros(n_params),
        };

        let model_fit = ModelFitMetrics {
            dic: F::from(100.0).expect("Failed to convert constant to float"),
            waic: F::from(105.0).expect("Failed to convert constant to float"),
            lppd: F::from(-50.0).expect("Failed to convert constant to float"),
            p_eff: F::from(n_params).expect("Failed to convert to float"),
            posterior_p_value: F::from(0.5).expect("Failed to convert constant to float"),
        };

        let predictions = PredictiveDistribution {
            means: Array1::zeros(y.len()),
            variances: Array1::ones(y.len()),
            quantiles: Array2::zeros((y.len(), 3)),
            samples: Array2::zeros((100, y.len())),
        };

        Ok(AdvancedBayesianResult {
            posterior_samples,
            posterior_summary,
            diagnostics,
            model_fit,
            predictions,
        })
    }

    /// Maps a selection criterion to its value for a fitted model.
    ///
    /// NOTE: LOO-CV, PPL, and CVIC are currently approximated as fixed offsets
    /// from WAIC; these are placeholders.
    fn compute_criterion(
        &self,
        result: &AdvancedBayesianResult<F>,
        criterion: &ModelSelectionCriterion,
    ) -> StatsResult<F> {
        match criterion {
            ModelSelectionCriterion::DIC => Ok(result.model_fit.dic),
            ModelSelectionCriterion::WAIC => Ok(result.model_fit.waic),
            ModelSelectionCriterion::LooCv => {
                Ok(result.model_fit.waic
                    + F::from(1.0).expect("Failed to convert constant to float"))
            }
            ModelSelectionCriterion::MarginalLikelihood => Ok(result.model_fit.lppd),
            ModelSelectionCriterion::PPL => {
                Ok(result.model_fit.waic
                    + F::from(2.0).expect("Failed to convert constant to float"))
            }
            ModelSelectionCriterion::CVIC => {
                Ok(result.model_fit.waic
                    + F::from(0.5).expect("Failed to convert constant to float"))
            }
        }
    }

    /// Cross-validates a single model.
    ///
    /// NOTE: placeholder implementation; fold scores are constant until
    /// per-fold refitting is implemented.
    fn cross_validate_model(
        &self,
        _model: &BayesianModel<F>,
        x: &ArrayView2<F>,
        _y: &ArrayView1<F>,
    ) -> StatsResult<CrossValidationResult<F>> {
        let k = self.cv_config.k_folds;
        let fold_scores = Array1::ones(k);
        let mean_score = F::one();
        let std_error = F::from(0.1).expect("Failed to convert constant to float");
        let effective_n_params = F::from(x.ncols()).expect("Failed to convert to float");

        Ok(CrossValidationResult {
            mean_score,
            std_error,
            fold_scores,
            effective_n_params,
        })
    }

    /// Computes Akaike-style model weights from WAIC:
    /// `w_i = exp(-(WAIC_i - WAIC_min) / 2)`, normalised to sum to one.
    fn compute_model_weights(
        &self,
        ic_values: &HashMap<String, HashMap<ModelSelectionCriterion, F>>,
    ) -> StatsResult<HashMap<String, F>> {
        let mut weights = HashMap::new();

        // Error out (rather than panic) if WAIC was not among the criteria.
        let waic_values: Vec<(String, F)> = ic_values
            .iter()
            .map(|(id, scores)| {
                scores
                    .get(&ModelSelectionCriterion::WAIC)
                    .copied()
                    .map(|waic| (id.clone(), waic))
                    .ok_or_else(|| {
                        StatsError::InvalidArgument(
                            "Model weights require WAIC among the selection criteria".to_string(),
                        )
                    })
            })
            .collect::<StatsResult<Vec<_>>>()?;

        let min_waic = waic_values
            .iter()
            .map(|(_, waic)| *waic)
            .fold(F::infinity(), |a, b| if a < b { a } else { b });

        let two = F::from(2.0).expect("Failed to convert constant to float");
        let weight_sum: F = waic_values
            .iter()
            .map(|(_, waic)| (-((*waic - min_waic) / two)).exp())
            .sum();

        for (id, waic) in waic_values {
            let weight = (-((waic - min_waic) / two)).exp() / weight_sum;
            weights.insert(id, weight);
        }

        Ok(weights)
    }
}

impl Default for CrossValidationConfig {
    fn default() -> Self {
        Self {
            k_folds: 5,
            mc_samples: 1000,
            seed: None,
            stratify: false,
        }
    }
}

impl Default for ParallelConfig {
    fn default() -> Self {
        Self {
            num_chains: 4,
            parallel_models: true,
            parallel_cv: true,
        }
    }
}

impl Default for MCMCConfig {
    fn default() -> Self {
        Self {
            n_samples: 2000,
            n_burnin: 1000,
            thin: 1,
            n_chains: 4,
            adaptation_period: 500,
            target_acceptance: 0.65,
            use_nuts: true,
            use_hmc: false,
        }
    }
}

impl Default for VIConfig {
    fn default() -> Self {
        Self {
            max_iter: 10000,
            tolerance: 1e-6,
            learning_rate: 0.01,
            family: VariationalFamily::MeanFieldGaussian,
            n_mc_samples: 100,
        }
    }
}

impl<F> Default for BayesianModelComparison<F>
where
    F: Float
        + NumCast
        + SimdUnifiedOps
        + Zero
        + One
        + PartialOrd
        + Copy
        + Send
        + Sync
        + std::fmt::Display
        + std::iter::Sum<F>,
{
    fn default() -> Self {
        Self::new()
    }
}

impl<F> BayesianGaussianProcess<F>
where
    F: Float
        + NumCast
        + SimdUnifiedOps
        + Zero
        + One
        + PartialOrd
        + Copy
        + Send
        + Sync
        + std::fmt::Display,
{
    /// Validates the training data and constructs a GP with the given kernel
    /// and observation-noise level.
    pub fn new(
        x_train: Array2<F>,
        y_train: Array1<F>,
        kernel: KernelType,
        noise_level: F,
    ) -> StatsResult<Self> {
        checkarray_finite(&x_train.view(), "x_train")?;
        checkarray_finite(&y_train.view(), "y_train")?;

        if x_train.nrows() != y_train.len() {
            return Err(StatsError::DimensionMismatch(
                "x_train and y_train must have the same number of observations".to_string(),
            ));
        }

        if noise_level <= F::zero() {
            return Err(StatsError::InvalidArgument(
                "Noise level must be positive".to_string(),
            ));
        }

        Ok(Self {
            x_train,
            y_train,
            kernel,
            noise_level,
            hyperpriors: HashMap::new(),
            hyperparameter_samples: None,
        })
    }

    /// Evaluates the Gram matrix `K[i, j] = k(x1_i, x2_j)` for the configured kernel.
    pub fn compute_kernel_matrix(
        &self,
        x1: &ArrayView2<F>,
        x2: &ArrayView2<F>,
    ) -> StatsResult<Array2<F>> {
        let n1 = x1.nrows();
        let n2 = x2.nrows();
        let mut k = Array2::zeros((n1, n2));

        for i in 0..n1 {
            for j in 0..n2 {
                let x1_row = x1.row(i);
                let x2_row = x2.row(j);
                k[[i, j]] = self.kernel_function(&x1_row, &x2_row)?;
            }
        }

        Ok(k)
    }

    /// Evaluates the kernel between two points.
    fn kernel_function(&self, x1: &ArrayView1<F>, x2: &ArrayView1<F>) -> StatsResult<F> {
        match &self.kernel {
            // Squared-exponential kernel: k(x, x') = exp(-||x - x'||^2 / (2 l^2)).
            KernelType::RBF { length_scale } => {
                let length_scale = F::from(*length_scale).expect("Failed to convert to float");
                let mut squared_dist = F::zero();

                for (a, b) in x1.iter().zip(x2.iter()) {
                    let diff = *a - *b;
                    squared_dist = squared_dist + diff * diff;
                }

                Ok((-squared_dist
                    / (F::from(2.0).expect("Failed to convert constant to float")
                        * length_scale
                        * length_scale))
                    .exp())
            }
            // Matern kernel; only nu = 3/2 has a closed form here, and other
            // values of nu fall back to an RBF-style kernel.
            KernelType::Matern { nu, length_scale } => {
                let nu = F::from(*nu).expect("Failed to convert to float");
                let length_scale = F::from(*length_scale).expect("Failed to convert to float");
                let mut dist = F::zero();

                for (a, b) in x1.iter().zip(x2.iter()) {
                    let diff = *a - *b;
                    dist = dist + diff * diff;
                }
                dist = dist.sqrt();

                if nu == F::from(1.5).expect("Failed to convert constant to float") {
                    // Matern 3/2: k(r) = (1 + sqrt(3) r / l) exp(-sqrt(3) r / l).
                    let sqrt3_r_l = F::from(3.0)
                        .expect("Failed to convert constant to float")
                        .sqrt()
                        * dist
                        / length_scale;
                    Ok((F::one() + sqrt3_r_l) * (-sqrt3_r_l).exp())
                } else {
                    Ok((-dist * dist
                        / (F::from(2.0).expect("Failed to convert constant to float")
                            * length_scale
                            * length_scale))
                        .exp())
                }
            }
            KernelType::Linear { variance } => {
                let variance = F::from(*variance).expect("Failed to convert to float");
                let dot_product = F::simd_dot(x1, x2);
                Ok(variance * dot_product)
            }
            KernelType::WhiteNoise { variance } => {
                let variance = F::from(*variance).expect("Failed to convert to float");
                // White noise only contributes on (numerically) identical inputs.
                let mut is_equal = true;
                for (a, b) in x1.iter().zip(x2.iter()) {
                    if (*a - *b).abs()
                        > F::from(1e-10).expect("Failed to convert constant to float")
                    {
                        is_equal = false;
                        break;
                    }
                }
                Ok(if is_equal { variance } else { F::zero() })
            }
            // Remaining kernel types currently fall back to a unit-length-scale RBF.
            _ => {
                let mut squared_dist = F::zero();
                for (a, b) in x1.iter().zip(x2.iter()) {
                    let diff = *a - *b;
                    squared_dist = squared_dist + diff * diff;
                }
                Ok(
                    (-squared_dist / F::from(2.0).expect("Failed to convert constant to float"))
                        .exp(),
                )
            }
        }
    }

    /// Predicts mean and variance at the test points.
    ///
    /// NOTE: placeholder implementation. The predictive mean is a
    /// nearest-neighbour lookup over the training set and the variance is the
    /// observation-noise level, rather than the full GP posterior.
    pub fn predict(&self, xtest: &ArrayView2<F>) -> StatsResult<(Array1<F>, Array1<F>)> {
        checkarray_finite(xtest, "x_test")?;

        let n_test = xtest.nrows();

        let mut mean_pred = Array1::zeros(n_test);
        let mut var_pred = Array1::zeros(n_test);

        let n_train = self.x_train.nrows();

        for i in 0..n_test {
            let test_point = xtest.row(i);
            let mut min_dist = F::infinity();
            let mut nearest_y = F::zero();

            for j in 0..n_train {
                let train_point = self.x_train.row(j);
                let mut dist = F::zero();
                for (a, b) in test_point.iter().zip(train_point.iter()) {
                    let diff = *a - *b;
                    dist = dist + diff * diff;
                }

                if dist < min_dist {
                    min_dist = dist;
                    nearest_y = self.y_train[j];
                }
            }

            mean_pred[i] = nearest_y;
            var_pred[i] = self.noise_level;
        }

        Ok((mean_pred, var_pred))
    }
}

impl<F> BayesianNeuralNetwork<F>
where
    F: Float
        + NumCast
        + SimdUnifiedOps
        + Zero
        + One
        + PartialOrd
        + Copy
        + Send
        + Sync
        + std::fmt::Display,
{
    /// Validates the architecture and sets up per-layer weight and bias priors.
    pub fn new(architecture: Vec<usize>, activations: Vec<ActivationType>) -> StatsResult<Self> {
        if architecture.len() < 2 {
            return Err(StatsError::InvalidArgument(
                "Architecture must have at least input and output layers".to_string(),
            ));
        }

        if activations.len() != architecture.len() - 1 {
            return Err(StatsError::InvalidArgument(
                "Number of activations must equal number of layers - 1".to_string(),
            ));
        }

        let n_layers = architecture.len() - 1;

        let weight_priors = (0..n_layers)
            .map(|i| {
                // Zero-mean Gaussian prior whose precision scales with fan-in.
                let fan_in = F::from(architecture[i]).expect("Failed to convert to float");
                let precision = fan_in;
                DistributionType::Normal {
                    mean: F::zero(),
                    precision,
                }
            })
            .collect();

        let bias_priors = (0..n_layers)
            .map(|_| DistributionType::Normal {
                mean: F::zero(),
                precision: F::from(0.1).expect("Failed to convert constant to float"),
            })
            .collect();

        Ok(Self {
            architecture,
            activations,
            weight_priors,
            bias_priors,
            weight_samples: None,
            bias_samples: None,
        })
    }

    /// Applies the given activation function elementwise.
    fn apply_activation(&self, x: F, activation: ActivationType) -> F {
        match activation {
            ActivationType::ReLU => {
                if x > F::zero() {
                    x
                } else {
                    F::zero()
                }
            }
            ActivationType::Sigmoid => F::one() / (F::one() + (-x).exp()),
            ActivationType::Tanh => x.tanh(),
            ActivationType::Swish => x / (F::one() + (-x).exp()),
            ActivationType::GELU => {
                // tanh approximation of GELU; 0.7978845608 ~= sqrt(2 / pi).
                let sqrt_2_pi = F::from(0.7978845608).expect("Failed to convert constant to float");
                let coeff = F::from(0.044715).expect("Failed to convert constant to float");
                let inner = sqrt_2_pi * (x + coeff * x * x * x);
                F::from(0.5).expect("Failed to convert constant to float")
                    * x
                    * (F::one() + inner.tanh())
            }
        }
    }

    pub fn forward(
        &self,
        x: &ArrayView2<F>,
        weights: &[Array2<F>],
        biases: &[Array1<F>],
    ) -> StatsResult<Array2<F>> {
        checkarray_finite(x, "x")?;

        if weights.len() != self.architecture.len() - 1 {
            return Err(StatsError::InvalidArgument(
                "Number of weight matrices must match network layers".to_string(),
            ));
        }

        if biases.len() != self.architecture.len() - 1 {
            return Err(StatsError::InvalidArgument(
                "Number of bias vectors must match network layers".to_string(),
            ));
        }

        let mut activations = x.to_owned();

        for (layer_idx, &activation_type) in self.activations.iter().enumerate() {
            let z = self.linear_transform(
                &activations.view(),
                &weights[layer_idx],
                &biases[layer_idx],
            )?;

            activations = z.mapv(|val| self.apply_activation(val, activation_type));
        }

        Ok(activations)
    }

    /// Computes `x @ weights + bias` with a naive triple loop.
    fn linear_transform(
        &self,
        x: &ArrayView2<F>,
        weights: &Array2<F>,
        bias: &Array1<F>,
    ) -> StatsResult<Array2<F>> {
        let (batchsize, input_dim) = x.dim();
        let (weight_input_dim, output_dim) = weights.dim();

        if input_dim != weight_input_dim {
            return Err(StatsError::DimensionMismatch(
                "Input dimension must match weight matrix input dimension".to_string(),
            ));
        }

        if bias.len() != output_dim {
            return Err(StatsError::DimensionMismatch(
                "Bias length must match weight matrix output dimension".to_string(),
            ));
        }

        let mut result = Array2::zeros((batchsize, output_dim));

        for i in 0..batchsize {
            for j in 0..output_dim {
                let mut sum = F::zero();
                for k in 0..input_dim {
                    sum = sum + x[[i, k]] * weights[[k, j]];
                }
                result[[i, j]] = sum + bias[j];
            }
        }

        Ok(result)
    }

    /// Draws from `N(mean, 1/precision)` via the Box-Muller transform.
    ///
    /// NOTE: placeholder; both uniform draws are fixed at 0.5, so the output
    /// is currently deterministic rather than random.
    fn sample_from_normal(mean: F, precision: F) -> StatsResult<F> {
        let u1 = F::from(0.5).expect("Failed to convert constant to float");
        let u2 = F::from(0.5).expect("Failed to convert constant to float");

        let z = (-F::from(2.0).expect("Failed to convert constant to float") * u1.ln()).sqrt()
            * (F::from(2.0 * std::f64::consts::PI).expect("Failed to convert to float") * u2).cos();

        let std_dev = F::one() / precision.sqrt();
        Ok(mean + std_dev * z)
    }

    /// Predicts with per-output uncertainty estimates.
    ///
    /// NOTE: placeholder; returns zero means and unit variances instead of
    /// averaging forward passes over posterior weight samples.
    pub fn predict_with_uncertainty(
        &self,
        x: &ArrayView2<F>,
        _n_samples: usize,
    ) -> StatsResult<(Array2<F>, Array2<F>)> {
        checkarray_finite(x, "x")?;

        let n_test = x.nrows();
        let output_dim = self
            .architecture
            .last()
            .expect("architecture has at least two layers");

        let mut predictions = Array2::zeros((n_test, *output_dim));
        let mut prediction_vars = Array2::zeros((n_test, *output_dim));

        for i in 0..n_test {
            for j in 0..*output_dim {
                predictions[[i, j]] = F::zero();
                prediction_vars[[i, j]] = F::one();
            }
        }

        Ok((predictions, prediction_vars))
    }
}


#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_model_comparison() {
        let mut comparison = BayesianModelComparison::<f64>::new();

        let model = BayesianModel {
            id: "linear_model".to_string(),
            model_type: ModelType::LinearRegression,
            prior: AdvancedPrior::Conjugate {
                parameters: HashMap::new(),
            },
            likelihood: LikelihoodType::Gaussian,
            complexity: 3.0,
        };

        comparison.add_model(model);

        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![1.0, 2.0, 3.0];

        let result = comparison.compare_models(&x.view(), &y.view());
        assert!(result.is_ok());
    }
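
    // Added example: checks the WAIC-based model weights. With the placeholder
    // fit, every model receives the same WAIC, so the weights should be
    // uniform and sum to one.
    #[test]
    fn test_model_weights_sum_to_one() {
        let mut comparison = BayesianModelComparison::<f64>::new();

        for id in ["model_a", "model_b"].iter() {
            comparison.add_model(BayesianModel {
                id: id.to_string(),
                model_type: ModelType::LinearRegression,
                prior: AdvancedPrior::Conjugate {
                    parameters: HashMap::new(),
                },
                likelihood: LikelihoodType::Gaussian,
                complexity: 2.0,
            });
        }

        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![1.0, 2.0, 3.0];

        let result = comparison
            .compare_models(&x.view(), &y.view())
            .expect("comparison should succeed");

        let total: f64 = result.model_weights.values().sum();
        assert!((total - 1.0).abs() < 1e-10);
        for weight in result.model_weights.values() {
            assert!((weight - 0.5).abs() < 1e-10);
        }
    }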

    #[test]
    fn test_gaussian_process() {
        let x_train = array![[1.0], [2.0], [3.0]];
        let y_train = array![1.0, 4.0, 9.0];
        let gp = BayesianGaussianProcess::new(
            x_train.clone(),
            y_train.clone(),
            KernelType::RBF { length_scale: 1.0 },
            0.1,
        )
        .expect("GP construction should succeed");

        assert_eq!(gp.x_train.nrows(), 3);
        assert_eq!(gp.y_train.len(), 3);

        let x_test = array![[1.5], [2.5]];
        let result = gp.predict(&x_test.view());
        assert!(result.is_ok());
    }
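
    // Added example: sanity-checks the RBF Gram matrix; k(x, x) should be 1 on
    // the diagonal and the matrix should be symmetric.
    #[test]
    fn test_rbf_kernel_matrix() {
        let x_train = array![[0.0], [1.0], [2.0]];
        let y_train = array![0.0, 1.0, 4.0];
        let gp = BayesianGaussianProcess::new(
            x_train.clone(),
            y_train,
            KernelType::RBF { length_scale: 1.0 },
            0.1,
        )
        .expect("GP construction should succeed");

        let k = gp
            .compute_kernel_matrix(&x_train.view(), &x_train.view())
            .expect("kernel matrix should be computable");

        for i in 0..3 {
            assert!((k[[i, i]] - 1.0).abs() < 1e-12);
            for j in 0..3 {
                assert!((k[[i, j]] - k[[j, i]]).abs() < 1e-12);
            }
        }
    }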

    #[test]
    fn test_bayesian_neural_network() {
        let bnn = BayesianNeuralNetwork::new(
            vec![2, 5, 1],
            vec![ActivationType::ReLU, ActivationType::Sigmoid],
        )
        .expect("BNN construction should succeed");

        assert_eq!(bnn.architecture.len(), 3);
        assert_eq!(bnn.activations.len(), 2);

        let x_test = array![[1.0, 2.0], [3.0, 4.0]];
        let result = bnn.predict_with_uncertainty(&x_test.view(), 10);
        assert!(result.is_ok());
    }
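
    // Added example: with all-zero weights and biases every pre-activation is
    // zero, so a final sigmoid layer should output exactly 0.5 everywhere.
    #[test]
    fn test_forward_pass_with_zero_weights() {
        let bnn = BayesianNeuralNetwork::new(
            vec![2, 3, 1],
            vec![ActivationType::ReLU, ActivationType::Sigmoid],
        )
        .expect("BNN construction should succeed");

        let weights = vec![Array2::<f64>::zeros((2, 3)), Array2::<f64>::zeros((3, 1))];
        let biases = vec![Array1::<f64>::zeros(3), Array1::<f64>::zeros(1)];

        let x = array![[1.0, -2.0], [0.5, 3.0]];
        let out = bnn
            .forward(&x.view(), &weights, &biases)
            .expect("forward pass should succeed");

        assert_eq!(out.dim(), (2, 1));
        for value in out.iter() {
            assert!((value - 0.5).abs() < 1e-12);
        }
    }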
}