1use std::f64::consts::PI;
74
75use faer::linalg::{matmul, solvers::Solve};
76use faer::{Accum, Mat, Par, Side};
77
78use crate::gaussian::cholesky_logdet;
79use crate::{BLRError, Gaussian};
80
/// Prior / warm-start information for [`fit_with_prior`], for a model with
/// `d = mean.len()` features. Use [`BLRPrior::validate`] to check consistency.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct BLRPrior {
    /// Prior weight mean; one entry per feature (length `d`).
    pub mean: Vec<f64>,
    /// Prior weight covariance as a row-major `d x d` matrix (`cov[i * d + j]`);
    /// must admit a Cholesky factorization to pass validation.
    pub cov: Vec<f64>,
    /// Starting per-feature ARD precisions (length `d`); `fit_with_prior`
    /// uses them in place of `ArdConfig::alpha_init`.
    pub alphas: Vec<f64>,
}
108
109impl BLRPrior {
110 pub fn validate(&self) -> Result<(), BLRError> {
117 let d = self.mean.len();
118 if self.alphas.len() != d {
119 return Err(BLRError::DimMismatch {
120 expected: d,
121 got: self.alphas.len(),
122 });
123 }
124 if self.cov.len() != d * d {
125 return Err(BLRError::DimMismatch {
126 expected: d * d,
127 got: self.cov.len(),
128 });
129 }
130 if d == 0 {
131 return Err(BLRError::DimMismatch {
132 expected: 1,
133 got: 0,
134 });
135 }
136 let cov_mat = Mat::<f64>::from_fn(d, d, |i, j| self.cov[i * d + j]);
138 cov_mat
139 .llt(Side::Lower)
140 .map_err(|_| BLRError::SingularMatrix)?;
141 Ok(())
142 }
143}
144
/// Settings for the ARD evidence-maximisation loop run by [`fit`] and
/// [`fit_with_prior`].
#[derive(Debug, Clone, PartialEq)]
pub struct ArdConfig {
    /// Initial prior precision given to every weight when no prior `alphas`
    /// are supplied.
    pub alpha_init: f64,
    /// Initial noise precision `beta`.
    pub beta_init: f64,
    /// Maximum number of re-estimation iterations.
    pub max_iter: usize,
    /// Stop once the (pair-averaged) change in recorded log evidence falls
    /// below this value.
    pub tol: f64,
    /// Whether `beta` is re-estimated each iteration; when `false` it stays
    /// at `beta_init`.
    pub update_beta: bool,
}
164
165impl Default for ArdConfig {
166 fn default() -> Self {
167 Self {
168 alpha_init: 1.0,
169 beta_init: 1.0,
170 max_iter: 100,
171 tol: 1e-5,
172 update_beta: true,
173 }
174 }
175}
176
/// Per-point predictive summary returned by [`FittedArd::predict`].
#[derive(Debug, Clone, PartialEq)]
pub struct PredictiveMarginals {
    /// Predictive mean `phi_i^T mu` for each test row.
    pub mean: Vec<f64>,
    /// Noise (aleatoric) standard deviation, shared by every test row.
    pub aleatoric_std: f64,
    /// Standard deviation from weight uncertainty, `sqrt(phi_i^T Sigma phi_i)`,
    /// for each test row.
    pub epistemic_std: Vec<f64>,
    /// Total predictive standard deviation for each test row, combining the
    /// aleatoric and epistemic variances.
    pub total_std: Vec<f64>,
}
193
/// Result of [`fit`] / [`fit_with_prior`].
///
/// NOTE: `posterior` is computed from the hyper-parameters in effect at the
/// start of the final iteration, while `alpha` and `beta` hold the values
/// after that iteration's re-estimation step, so they are one update ahead of
/// the posterior.
pub struct FittedArd {
    /// Gaussian posterior over the weights: mean of length `d`, covariance as
    /// a row-major `d x d` matrix.
    pub posterior: Gaussian,
    /// Per-feature prior precisions after the final re-estimation step.
    pub alpha: Vec<f64>,
    /// Noise precision after the final re-estimation step (stays at
    /// `ArdConfig::beta_init` when `update_beta` is false).
    pub beta: f64,
    /// Log-evidence value recorded at each iteration, in order.
    pub log_evidences: Vec<f64>,
    /// Number of training samples `n` the model was fitted on.
    pub n_samples: usize,
}
209
210impl FittedArd {
211 pub fn predict(
217 &self,
218 phi_test: &[f64],
219 n_test: usize,
220 n_features: usize,
221 ) -> PredictiveMarginals {
222 let d = n_features;
223 let sigma_mat = Mat::<f64>::from_fn(d, d, |i, j| self.posterior.cov[i * d + j]);
224 let mu_col = Mat::<f64>::from_fn(d, 1, |i, _| self.posterior.mean[i]);
225
226 let aleatoric_var = 1.0 / self.beta;
227 let aleatoric_std = aleatoric_var.sqrt();
228
229 let mut mean = Vec::with_capacity(n_test);
230 let mut epistemic_std = Vec::with_capacity(n_test);
231 let mut total_std = Vec::with_capacity(n_test);
232
233 for i in 0..n_test {
234 let phi_row = Mat::<f64>::from_fn(1, d, |_, j| phi_test[i * d + j]);
235
236 let mut m_mat = Mat::<f64>::zeros(1, 1);
238 matmul::matmul(
239 m_mat.as_mut(),
240 Accum::Replace,
241 phi_row.as_ref(),
242 mu_col.as_ref(),
243 1.0_f64,
244 Par::Seq,
245 );
246 mean.push(m_mat[(0, 0)]);
247
248 let mut sigma_phi_t = Mat::<f64>::zeros(d, 1);
250 matmul::matmul(
251 sigma_phi_t.as_mut(),
252 Accum::Replace,
253 sigma_mat.as_ref(),
254 phi_row.as_ref().transpose(),
255 1.0_f64,
256 Par::Seq,
257 );
258 let mut ep_var_mat = Mat::<f64>::zeros(1, 1);
259 matmul::matmul(
260 ep_var_mat.as_mut(),
261 Accum::Replace,
262 phi_row.as_ref(),
263 sigma_phi_t.as_ref(),
264 1.0_f64,
265 Par::Seq,
266 );
267 let ep_var = ep_var_mat[(0, 0)].max(0.0);
268 epistemic_std.push(ep_var.sqrt());
269 total_std.push((aleatoric_var + ep_var).sqrt());
270 }
271
272 PredictiveMarginals {
273 mean,
274 aleatoric_std,
275 epistemic_std,
276 total_std,
277 }
278 }
279
280 pub fn predict_gaussian(
284 &self,
285 phi_test: &[f64],
286 n_test: usize,
287 n_features: usize,
288 ) -> Result<Gaussian, BLRError> {
289 let d = n_features;
290 let m = n_test;
291
292 let phi_mat = Mat::<f64>::from_fn(m, d, |i, j| phi_test[i * d + j]);
293 let sigma_mat = Mat::<f64>::from_fn(d, d, |i, j| self.posterior.cov[i * d + j]);
294 let mu_col = Mat::<f64>::from_fn(d, 1, |i, _| self.posterior.mean[i]);
295
296 let mut pred_mean_mat = Mat::<f64>::zeros(m, 1);
298 matmul::matmul(
299 pred_mean_mat.as_mut(),
300 Accum::Replace,
301 phi_mat.as_ref(),
302 mu_col.as_ref(),
303 1.0_f64,
304 Par::Seq,
305 );
306
307 let mut tmp = Mat::<f64>::zeros(m, d);
310 matmul::matmul(
311 tmp.as_mut(),
312 Accum::Replace,
313 phi_mat.as_ref(),
314 sigma_mat.as_ref(),
315 1.0_f64,
316 Par::Seq,
317 );
318 let mut pred_cov = Mat::<f64>::zeros(m, m);
320 matmul::matmul(
321 pred_cov.as_mut(),
322 Accum::Replace,
323 tmp.as_ref(),
324 phi_mat.as_ref().transpose(),
325 1.0_f64,
326 Par::Seq,
327 );
328 let noise_var = 1.0 / self.beta;
330 for i in 0..m {
331 pred_cov[(i, i)] += noise_var + 1e-9; }
333
334 let pred_cov_ref = pred_cov.as_ref();
335 let pred_mean_vec: Vec<f64> = (0..m).map(|i| pred_mean_mat[(i, 0)]).collect();
336 let pred_cov_vec: Vec<f64> = (0..m)
337 .flat_map(|i| (0..m).map(move |j| pred_cov_ref[(i, j)]))
338 .collect();
339
340 Gaussian::new(pred_mean_vec, pred_cov_vec)
341 }
342
343 pub fn relevance(&self) -> Vec<f64> {
347 self.alpha.iter().map(|a| 1.0 / a).collect()
348 }
349
350 pub fn relevant_features(&self, threshold: Option<f64>) -> Vec<bool> {
357 let t = threshold.unwrap_or_else(|| {
358 let ln_mean = self.alpha.iter().map(|a| a.ln()).sum::<f64>() / self.alpha.len() as f64;
359 ln_mean.exp()
360 });
361 self.alpha.iter().map(|a| *a < t).collect()
362 }
363
364 pub fn noise_std(&self) -> f64 {
368 1.0 / self.beta.sqrt()
369 }
370
371 pub fn log_marginal_likelihood(&self) -> f64 {
373 *self.log_evidences.last().unwrap_or(&f64::NEG_INFINITY)
374 }
375
376 pub fn noise_precision(&self) -> f64 {
380 self.beta
381 }
382
383 pub fn posterior_covariance(&self) -> &[f64] {
385 &self.posterior.cov
386 }
387
388 pub fn sample_count(&self) -> usize {
390 self.n_samples
391 }
392
393 pub fn posterior_std(&self, phi_test: &[f64], n_test: usize, n_features: usize) -> Vec<f64> {
403 let d = n_features;
404 let sigma_cov = &self.posterior.cov;
405 let noise_var = 1.0 / self.beta.max(1e-10);
406 (0..n_test)
407 .map(|i| {
408 let phi_i = &phi_test[i * d..(i + 1) * d];
409 let mut sigma_phi = vec![0.0_f64; d];
410 for row in 0..d {
411 for col in 0..d {
412 sigma_phi[row] += sigma_cov[row * d + col] * phi_i[col];
413 }
414 }
415 let epistemic: f64 = phi_i.iter().zip(sigma_phi.iter()).map(|(a, b)| a * b).sum();
416 (noise_var + epistemic.max(0.0)).sqrt()
417 })
418 .collect()
419 }
420
421 pub fn posterior_std_grid(
431 &self,
432 input_range: (f64, f64),
433 resolution: usize,
434 feature_fn: &dyn Fn(f64) -> Vec<f64>,
435 ) -> (Vec<f64>, Vec<f64>) {
436 let d_sq = self.posterior.cov.len();
437 let d = (d_sq as f64).sqrt() as usize;
438 let resolution = resolution.max(2);
439 let step = (input_range.1 - input_range.0) / (resolution - 1) as f64;
440 let grid: Vec<f64> = (0..resolution)
441 .map(|k| input_range.0 + k as f64 * step)
442 .collect();
443 let mut phi_grid = Vec::with_capacity(resolution * d);
444 for &x in &grid {
445 let feats = feature_fn(x);
446 let actual = feats.len().min(d);
447 phi_grid.extend_from_slice(&feats[..actual]);
448 if actual < d {
449 phi_grid.extend(std::iter::repeat(0.0).take(d - actual));
450 }
451 }
452 let stds = self.posterior_std(&phi_grid, resolution, d);
453 (grid, stds)
454 }
455}
456
/// Log marginal likelihood `ln p(y | alpha, beta)` of Bayesian linear
/// regression with weight prior `w ~ N(0, diag(alpha)^-1)` and Gaussian noise
/// of precision `beta`. It is expressed through posterior-mode quantities:
///
/// ```text
/// 0.5 * ( sum_j ln alpha_j + n ln beta - ln|Sigma^-1|
///         - beta * ||y - Phi mu||^2 - mu^T diag(alpha) mu )
///   - (n / 2) ln(2 pi)
/// ```
///
/// The prior normalizer `(2 pi)^(-d/2)` and the `(2 pi)^(d/2)` produced by
/// integrating out `w` cancel, so no `d`-dependent `2 pi` term appears. The
/// value is exact when `mu`, `logdet_sigma_inv` and `residual_sq` come from the
/// posterior at the same `alpha` and `beta`.
///
/// # Parameters
/// - `n`: number of samples.
/// - `d`: number of features; only used to sanity-check slice lengths.
/// - `alpha`: per-feature prior precisions (length `d`).
/// - `beta`: noise precision.
/// - `mu`: posterior mean (length `d`).
/// - `logdet_sigma_inv`: `ln|Sigma^-1|` of the posterior precision.
/// - `residual_sq`: `||y - Phi mu||^2`.
fn log_evidence(
    n: usize,
    d: usize,
    alpha: &[f64],
    beta: f64,
    mu: &[f64],
    logdet_sigma_inv: f64,
    residual_sq: f64,
) -> f64 {
    debug_assert_eq!(alpha.len(), d, "alpha must have one entry per feature");
    debug_assert_eq!(mu.len(), d, "mu must have one entry per feature");

    let n = n as f64;
    let log_alpha_sum: f64 = alpha.iter().map(|a| a.ln()).sum();
    // mu^T diag(alpha) mu
    let mu_lambda_mu: f64 = alpha.iter().zip(mu).map(|(a, m)| a * m * m).sum();

    0.5 * (log_alpha_sum + n * beta.ln() - logdet_sigma_inv - beta * residual_sq - mu_lambda_mu)
        - 0.5 * n * (2.0 * PI).ln()
}
483
484pub fn fit(
498 phi: &[f64],
499 y: &[f64],
500 n: usize,
501 d: usize,
502 config: &ArdConfig,
503) -> Result<FittedArd, BLRError> {
504 if phi.len() != n * d {
505 return Err(BLRError::DimMismatch {
506 expected: n * d,
507 got: phi.len(),
508 });
509 }
510 if y.len() != n {
511 return Err(BLRError::DimMismatch {
512 expected: n,
513 got: y.len(),
514 });
515 }
516
517 let phi_mat = Mat::<f64>::from_fn(n, d, |i, j| phi[i * d + j]);
518 let y_mat = Mat::<f64>::from_fn(n, 1, |i, _| y[i]);
519
520 let mut phi_t_phi = Mat::<f64>::zeros(d, d);
522 matmul::matmul(
523 phi_t_phi.as_mut(),
524 Accum::Replace,
525 phi_mat.as_ref().transpose(),
526 phi_mat.as_ref(),
527 1.0_f64,
528 Par::Seq,
529 );
530
531 let mut phi_t_y = Mat::<f64>::zeros(d, 1);
532 matmul::matmul(
533 phi_t_y.as_mut(),
534 Accum::Replace,
535 phi_mat.as_ref().transpose(),
536 y_mat.as_ref(),
537 1.0_f64,
538 Par::Seq,
539 );
540
541 let mut alpha = vec![config.alpha_init; d];
543 let mut beta = config.beta_init;
544 let mut log_evidences: Vec<f64> = Vec::new();
545
546 let mut sigma_mat = Mat::<f64>::zeros(d, d);
548 let mut mu_vec = vec![0.0_f64; d];
549
550 for _iter in 0..config.max_iter {
551 let mut sigma_inv = Mat::<f64>::from_fn(d, d, |i, j| beta * phi_t_phi[(i, j)]);
554 for j in 0..d {
555 sigma_inv[(j, j)] += alpha[j];
556 }
557
558 let llt = sigma_inv
560 .llt(Side::Lower)
561 .map_err(|_| BLRError::SingularMatrix)?;
562
563 let eye = Mat::<f64>::identity(d, d);
565 sigma_mat = llt.solve(eye.as_ref());
566
567 let mut rhs = phi_t_y.clone();
569 for i in 0..d {
570 rhs[(i, 0)] *= beta;
571 }
572 let mu_mat = llt.solve(rhs.as_ref());
573 for i in 0..d {
574 mu_vec[i] = mu_mat[(i, 0)];
575 }
576
577 let logdet_sigma_inv = cholesky_logdet(&sigma_inv, d)?;
579
580 let mut phi_mu = Mat::<f64>::zeros(n, 1);
582 let mu_mat_ref = Mat::<f64>::from_fn(d, 1, |i, _| mu_vec[i]);
583 matmul::matmul(
584 phi_mu.as_mut(),
585 Accum::Replace,
586 phi_mat.as_ref(),
587 mu_mat_ref.as_ref(),
588 1.0_f64,
589 Par::Seq,
590 );
591 let residual_sq: f64 = (0..n)
592 .map(|i| {
593 let r = y[i] - phi_mu[(i, 0)];
594 r * r
595 })
596 .sum();
597
598 let gamma: Vec<f64> = (0..d).map(|j| 1.0 - alpha[j] * sigma_mat[(j, j)]).collect();
601
602 for j in 0..d {
604 alpha[j] = (gamma[j] / (mu_vec[j] * mu_vec[j] + 1e-10)).max(1e-8);
605 }
606
607 if config.update_beta {
609 let gamma_sum: f64 = gamma.iter().sum();
610 beta = ((n as f64 - gamma_sum) / (residual_sq + 1e-10)).max(1e-8);
611 }
612
613 let lml = log_evidence(n, d, &alpha, beta, &mu_vec, logdet_sigma_inv, residual_sq);
616 log_evidences.push(lml);
617
618 let n_ev = log_evidences.len();
620 let delta = if n_ev >= 4 {
621 let mean_curr = 0.5 * (log_evidences[n_ev - 1] + log_evidences[n_ev - 2]);
622 let mean_prev = 0.5 * (log_evidences[n_ev - 3] + log_evidences[n_ev - 4]);
623 (mean_curr - mean_prev).abs()
624 } else if n_ev >= 2 {
625 (log_evidences[n_ev - 1] - log_evidences[n_ev - 2]).abs()
626 } else {
627 f64::INFINITY
628 };
629
630 if delta < config.tol {
631 break;
632 }
633 }
634
635 let mu_final: Vec<f64> = mu_vec.clone();
637 let cov_final: Vec<f64> = {
638 let sigma_ref = sigma_mat.as_ref();
639 (0..d)
640 .flat_map(|i| (0..d).map(move |j| sigma_ref[(i, j)]))
641 .collect()
642 };
643 let posterior = Gaussian::new(mu_final, cov_final)?;
644
645 Ok(FittedArd {
646 posterior,
647 alpha,
648 beta,
649 log_evidences,
650 n_samples: n,
651 })
652}
653
654pub fn fit_with_prior(
677 phi: &[f64],
678 y: &[f64],
679 n: usize,
680 d: usize,
681 config: &ArdConfig,
682 prior: Option<&BLRPrior>,
683) -> Result<FittedArd, BLRError> {
684 if phi.len() != n * d {
685 return Err(BLRError::DimMismatch {
686 expected: n * d,
687 got: phi.len(),
688 });
689 }
690 if y.len() != n {
691 return Err(BLRError::DimMismatch {
692 expected: n,
693 got: y.len(),
694 });
695 }
696
697 if let Some(p) = prior {
699 p.validate()?;
700 if p.mean.len() != d {
701 return Err(BLRError::DimMismatch {
702 expected: d,
703 got: p.mean.len(),
704 });
705 }
706 }
707
708 let phi_mat = Mat::<f64>::from_fn(n, d, |i, j| phi[i * d + j]);
709 let y_mat = Mat::<f64>::from_fn(n, 1, |i, _| y[i]);
710
711 let mut phi_t_phi = Mat::<f64>::zeros(d, d);
713 matmul::matmul(
714 phi_t_phi.as_mut(),
715 Accum::Replace,
716 phi_mat.as_ref().transpose(),
717 phi_mat.as_ref(),
718 1.0_f64,
719 Par::Seq,
720 );
721
722 let mut phi_t_y = Mat::<f64>::zeros(d, 1);
723 matmul::matmul(
724 phi_t_y.as_mut(),
725 Accum::Replace,
726 phi_mat.as_ref().transpose(),
727 y_mat.as_ref(),
728 1.0_f64,
729 Par::Seq,
730 );
731
732 let mut alpha: Vec<f64> = prior
734 .map(|p| p.alphas.clone())
735 .unwrap_or_else(|| vec![config.alpha_init; d]);
736 let mut beta = config.beta_init;
737 let mut log_evidences: Vec<f64> = Vec::new();
738
739 let mut sigma_mat = Mat::<f64>::zeros(d, d);
741 let mut mu_vec: Vec<f64> = prior
743 .map(|p| p.mean.clone())
744 .unwrap_or_else(|| vec![0.0f64; d]);
745
746 for _iter in 0..config.max_iter {
747 let mut sigma_inv = Mat::<f64>::from_fn(d, d, |i, j| beta * phi_t_phi[(i, j)]);
750 for j in 0..d {
751 sigma_inv[(j, j)] += alpha[j];
752 }
753
754 let llt = sigma_inv
756 .llt(Side::Lower)
757 .map_err(|_| BLRError::SingularMatrix)?;
758
759 let eye = Mat::<f64>::identity(d, d);
761 sigma_mat = llt.solve(eye.as_ref());
762
763 let mut rhs = phi_t_y.clone();
765 for i in 0..d {
766 rhs[(i, 0)] *= beta;
767 }
768 let mu_mat = llt.solve(rhs.as_ref());
769 for i in 0..d {
770 mu_vec[i] = mu_mat[(i, 0)];
771 }
772
773 let logdet_sigma_inv = cholesky_logdet(&sigma_inv, d)?;
775
776 let mut phi_mu = Mat::<f64>::zeros(n, 1);
778 let mu_mat_ref = Mat::<f64>::from_fn(d, 1, |i, _| mu_vec[i]);
779 matmul::matmul(
780 phi_mu.as_mut(),
781 Accum::Replace,
782 phi_mat.as_ref(),
783 mu_mat_ref.as_ref(),
784 1.0_f64,
785 Par::Seq,
786 );
787 let residual_sq: f64 = (0..n)
788 .map(|i| {
789 let r = y[i] - phi_mu[(i, 0)];
790 r * r
791 })
792 .sum();
793
794 let gamma: Vec<f64> = (0..d).map(|j| 1.0 - alpha[j] * sigma_mat[(j, j)]).collect();
796
797 for j in 0..d {
798 alpha[j] = (gamma[j] / (mu_vec[j] * mu_vec[j] + 1e-10)).max(1e-8);
799 }
800
801 if config.update_beta {
802 let gamma_sum: f64 = gamma.iter().sum();
803 beta = ((n as f64 - gamma_sum) / (residual_sq + 1e-10)).max(1e-8);
804 }
805
806 let lml = log_evidence(n, d, &alpha, beta, &mu_vec, logdet_sigma_inv, residual_sq);
807 log_evidences.push(lml);
808
809 let n_ev = log_evidences.len();
810 let delta = if n_ev >= 4 {
811 let mean_curr = 0.5 * (log_evidences[n_ev - 1] + log_evidences[n_ev - 2]);
812 let mean_prev = 0.5 * (log_evidences[n_ev - 3] + log_evidences[n_ev - 4]);
813 (mean_curr - mean_prev).abs()
814 } else if n_ev >= 2 {
815 (log_evidences[n_ev - 1] - log_evidences[n_ev - 2]).abs()
816 } else {
817 f64::INFINITY
818 };
819
820 if delta < config.tol {
821 break;
822 }
823 }
824
825 let mu_final: Vec<f64> = mu_vec.clone();
826 let cov_final: Vec<f64> = {
827 let sigma_ref = sigma_mat.as_ref();
828 (0..d)
829 .flat_map(|i| (0..d).map(move |j| sigma_ref[(i, j)]))
830 .collect()
831 };
832 let posterior = Gaussian::new(mu_final, cov_final)?;
833
834 Ok(FittedArd {
835 posterior,
836 alpha,
837 beta,
838 log_evidences,
839 n_samples: n,
840 })
841}
#[cfg(test)]
mod tests {
    use super::*;

    /// `ArdConfig::default()` must keep its expected starting values.
    #[test]
    fn test_ard_config_defaults() {
        let cfg = ArdConfig::default();
        assert_eq!(cfg.alpha_init, 1.0);
        assert_eq!(cfg.beta_init, 1.0);
        assert_eq!(cfg.max_iter, 100);
        assert_eq!(cfg.tol, 1e-5);
        assert!(cfg.update_beta);
    }

    /// Smoke test: the evidence helper returns a finite value for benign inputs.
    #[test]
    fn test_log_evidence_helper() {
        // Arguments: n = 10, d = 3, unit precisions, beta = 1, zero mean,
        // log-det 5.0, residual 2.0.
        let lml = log_evidence(10, 3, &[1.0; 3], 1.0, &[0.0; 3], 5.0, 2.0);
        assert!(lml.is_finite(), "log_evidence = {lml}");
    }

    /// A consistent prior with identity covariance passes validation.
    #[test]
    fn test_blr_prior_valid() {
        let d = 3;
        let prior = BLRPrior {
            mean: vec![0.0; d],
            // 3x3 identity, row-major.
            cov: vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
            alphas: vec![1.0; d],
        };
        assert!(prior.validate().is_ok());
    }

    /// A 2x2 covariance paired with a 3-dimensional mean must be rejected.
    #[test]
    fn test_blr_prior_invalid_dimensions() {
        let prior = BLRPrior {
            mean: vec![0.0; 3],
            cov: vec![1.0, 0.0, 0.0, 1.0],
            alphas: vec![1.0; 3],
        };
        assert!(prior.validate().is_err());
    }

    /// A negative-definite covariance must fail the Cholesky check.
    #[test]
    fn test_blr_prior_not_psd() {
        let d = 2;
        let prior = BLRPrior {
            mean: vec![0.0; d],
            cov: vec![-1.0, 0.0, 0.0, -1.0],
            alphas: vec![1.0; d],
        };
        assert!(matches!(prior.validate(), Err(BLRError::SingularMatrix)));
    }

    /// Without a prior, `fit` and `fit_with_prior` must produce the same
    /// hyper-parameters.
    #[test]
    fn test_fit_with_prior_none_equals_fit() {
        // 3 samples x 3 features, row-major.
        let phi: Vec<f64> = vec![1.0, 0.5, 0.25, 2.0, 1.0, 0.5, 0.5, 0.25, 0.125];
        let y: Vec<f64> = vec![1.0, 2.0, 0.5];
        let config = ArdConfig::default();

        let r1 = fit(&phi, &y, 3, 3, &config).unwrap();
        let r2 = fit_with_prior(&phi, &y, 3, 3, &config, None).unwrap();

        assert_eq!(r1.alpha.len(), r2.alpha.len());
        for (a1, a2) in r1.alpha.iter().zip(r2.alpha.iter()) {
            assert!((a1 - a2).abs() < 1e-10, "alpha mismatch: {a1} vs {a2}");
        }
        assert!((r1.beta - r2.beta).abs() < 1e-10);
    }

    /// Supplying a valid prior yields a successful fit of the right dimension.
    #[test]
    fn test_fit_with_prior_some_compiles_and_runs() {
        let d = 3;
        // 5 samples x 3 features, row-major.
        let phi: Vec<f64> = vec![
            1.0, 0.5, 0.25, 2.0, 1.0, 0.5, 0.5, 0.25, 0.125, 1.5, 0.75, 0.3, 0.8, 0.4, 0.2,
        ];
        let y: Vec<f64> = vec![1.0, 2.0, 0.5, 1.5, 0.8];
        let config = ArdConfig::default();

        let prior = BLRPrior {
            mean: vec![0.5; d],
            cov: vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
            alphas: vec![0.5; d],
        };
        let result = fit_with_prior(&phi, &y, 5, d, &config, Some(&prior));
        assert!(
            result.is_ok(),
            "fit_with_prior should succeed: {:?}",
            result.err()
        );
        let fitted = result.unwrap();
        assert!(fitted.noise_std() > 0.0);
        assert_eq!(fitted.alpha.len(), d);
    }

    /// Warm-starting from a cold-start fit's alphas must not take more
    /// iterations than the cold start. Note that beta still starts at
    /// `config.beta_init` in both runs.
    #[test]
    fn test_fit_with_prior_convergence_faster() {
        let d = 3;
        let n = 5;
        // The targets equal the first design column exactly.
        let phi: Vec<f64> = vec![
            1.0, 0.5, 0.25, 2.0, 1.0, 0.5, 0.5, 0.25, 0.125, 1.5, 0.75, 0.3, 0.8, 0.4, 0.2,
        ];
        let y: Vec<f64> = vec![1.0, 2.0, 0.5, 1.5, 0.8];
        // Tight tolerance and a generous iteration budget for both runs.
        let config = ArdConfig {
            max_iter: 200,
            tol: 1e-9,
            ..ArdConfig::default()
        };

        let baseline = fit_with_prior(&phi, &y, n, d, &config, None).unwrap();

        // Build the warm-start prior from the cold-start result.
        let prior = BLRPrior {
            mean: baseline.posterior.mean.clone(),
            cov: baseline.posterior.cov.clone(),
            alphas: baseline.alpha.clone(),
        };
        let informed = fit_with_prior(&phi, &y, n, d, &config, Some(&prior)).unwrap();

        assert!(informed.noise_std() > 0.0);
        // One log-evidence value is recorded per iteration, so the vector
        // lengths are the iteration counts.
        assert!(
            informed.log_evidences.len() <= baseline.log_evidences.len(),
            "informed iterations {} should be <= baseline iterations {}",
            informed.log_evidences.len(),
            baseline.log_evidences.len()
        );
    }
}