1use crate::error::{StatsError, StatsResult};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
8use scirs2_core::validation::*;
9use scirs2_linalg;
10use statrs::statistics::Statistics;
11
/// Conjugate Bayesian linear regression model with a Gaussian prior on the
/// weights and a gamma-type prior on the noise precision.
#[derive(Debug, Clone)]
pub struct BayesianLinearRegression {
    /// Prior mean of the regression weights (length = n_features).
    pub prior_mean: Array1<f64>,
    /// Prior precision (inverse covariance) matrix of the weights,
    /// shape (n_features, n_features).
    pub prior_precision: Array2<f64>,
    /// Shape parameter of the prior on the noise; updated to
    /// `posterior_alpha = prior_alpha + n/2` by `fit`.
    pub prior_alpha: f64,
    /// Rate parameter of the prior on the noise; updated by `fit`.
    pub prior_beta: f64,
    /// If true, `fit` centers x and y so an intercept is modeled implicitly
    /// and recovered at prediction time from the stored means.
    pub fit_intercept: bool,
}
31
/// Posterior returned by [`BayesianLinearRegression::fit`].
#[derive(Debug, Clone)]
pub struct BayesianRegressionResult {
    /// Posterior mean of the weights, mu_n (length = n_features).
    pub posterior_mean: Array1<f64>,
    /// Posterior covariance of the weights, Lambda_n^{-1}.
    pub posterior_covariance: Array2<f64>,
    /// Posterior shape parameter of the noise distribution.
    pub posterior_alpha: f64,
    /// Posterior rate parameter of the noise distribution.
    pub posterior_beta: f64,
    /// Number of training samples used in the fit.
    pub n_samples_: usize,
    /// Number of features the model was fitted on.
    pub n_features: usize,
    /// Column means of x when `fit_intercept` was true; `None` otherwise.
    pub x_mean: Option<Array1<f64>>,
    /// Mean of y when `fit_intercept` was true; `None` otherwise.
    pub y_mean: Option<f64>,
    /// Log model evidence, ln p(y | x), computed at the posterior.
    pub log_marginal_likelihood: f64,
}
54
55impl BayesianLinearRegression {
56 pub fn new(n_features: usize, fit_intercept: bool) -> StatsResult<Self> {
58 check_positive(n_features, "n_features")?;
59
60 let prior_mean = Array1::zeros(n_features);
62 let prior_precision = Array2::eye(n_features) * 1e-6; let prior_alpha = 1e-6; let prior_beta = 1e-6; Ok(Self {
67 prior_mean,
68 prior_precision,
69 prior_alpha,
70 prior_beta,
71 fit_intercept,
72 })
73 }
74
75 pub fn with_priors(
77 prior_mean: Array1<f64>,
78 prior_precision: Array2<f64>,
79 prior_alpha: f64,
80 prior_beta: f64,
81 fit_intercept: bool,
82 ) -> StatsResult<Self> {
83 checkarray_finite(&prior_mean, "prior_mean")?;
84 checkarray_finite(&prior_precision, "prior_precision")?;
85 check_positive(prior_alpha, "prior_alpha")?;
86 check_positive(prior_beta, "prior_beta")?;
87
88 if prior_precision.nrows() != prior_mean.len()
89 || prior_precision.ncols() != prior_mean.len()
90 {
91 return Err(StatsError::DimensionMismatch(format!(
92 "prior_precision shape ({}, {}) must match prior_mean length ({})",
93 prior_precision.nrows(),
94 prior_precision.ncols(),
95 prior_mean.len()
96 )));
97 }
98
99 Ok(Self {
100 prior_mean,
101 prior_precision,
102 prior_alpha,
103 prior_beta,
104 fit_intercept,
105 })
106 }
107
    /// Fit the conjugate Bayesian linear regression model.
    ///
    /// Computes the exact posterior over the weights given `x`
    /// (n_samples x n_features) and targets `y` (length n_samples), plus
    /// updated noise hyperparameters and the log model evidence.
    ///
    /// # Errors
    /// Fails on non-finite inputs, mismatched dimensions, fewer than two
    /// samples, or a non-invertible posterior precision matrix.
    pub fn fit(
        &self,
        x: ArrayView2<f64>,
        y: ArrayView1<f64>,
    ) -> StatsResult<BayesianRegressionResult> {
        checkarray_finite(&x, "x")?;
        checkarray_finite(&y, "y")?;
        let (n_samples_, n_features) = x.dim();

        if y.len() != n_samples_ {
            return Err(StatsError::DimensionMismatch(format!(
                "y length ({}) must match x rows ({})",
                y.len(),
                n_samples_
            )));
        }

        if n_samples_ < 2 {
            return Err(StatsError::InvalidArgument(
                "n_samples_ must be at least 2".to_string(),
            ));
        }

        // When fitting an intercept, center x and y instead of appending a
        // column of ones; the intercept is recovered at prediction time from
        // the stored means.
        let (x_centered, y_centered, x_mean, y_mean) = if self.fit_intercept {
            let x_mean = x.mean_axis(Axis(0)).expect("Operation failed");
            let y_mean = y.mean();

            let mut x_centered = x.to_owned();
            for mut row in x_centered.rows_mut() {
                row -= &x_mean;
            }

            let y_centered = &y.to_owned() - y_mean;

            (x_centered, y_centered, Some(x_mean), Some(y_mean))
        } else {
            (x.to_owned(), y.to_owned(), None, None)
        };

        // Sufficient statistics.
        let xtx = x_centered.t().dot(&x_centered);
        let xty = x_centered.t().dot(&y_centered);

        // Posterior precision: Lambda_n = Lambda_0 + X^T X.
        let posterior_precision = &self.prior_precision + &xtx;
        let posterior_covariance =
            scirs2_linalg::inv(&posterior_precision.view(), None).map_err(|e| {
                StatsError::ComputationError(format!("Failed to invert posterior precision: {}", e))
            })?;

        // Posterior mean: mu_n = Lambda_n^{-1} (Lambda_0 mu_0 + X^T y).
        let prior_contribution = self.prior_precision.dot(&self.prior_mean);
        let data_contribution = &xty;
        let posterior_mean = posterior_covariance.dot(&(&prior_contribution + data_contribution));

        // Noise shape update: a_n = a_0 + n / 2.
        let posterior_alpha = self.prior_alpha + n_samples_ as f64 / 2.0;

        // Residual sum of squares at the posterior mean.
        let y_pred = x_centered.dot(&posterior_mean);
        let residuals = &y_centered - &y_pred;
        let rss = residuals.dot(&residuals);

        // (mu_0 - mu_n)^T Lambda_0 (mu_0 - mu_n): shrinkage term contributed
        // by the prior.
        let prior_quad_form = (&self.prior_mean - &posterior_mean).t().dot(
            &self
                .prior_precision
                .dot(&(&self.prior_mean - &posterior_mean)),
        );

        // Noise rate update: b_n = b_0 + (RSS + quad form) / 2.
        let posterior_beta = self.prior_beta + 0.5 * (rss + prior_quad_form);

        // Log model evidence, useful for comparing prior/hyperparameter choices.
        let log_marginal = self.compute_log_marginal_likelihood(
            &x_centered,
            &y_centered,
            &posterior_precision,
            posterior_alpha,
            posterior_beta,
        )?;

        Ok(BayesianRegressionResult {
            posterior_mean,
            posterior_covariance,
            posterior_alpha,
            posterior_beta,
            n_samples_,
            n_features,
            x_mean,
            y_mean,
            log_marginal_likelihood: log_marginal,
        })
    }
203
204 fn compute_log_marginal_likelihood(
206 &self,
207 x: &Array2<f64>,
208 _y: &Array1<f64>,
209 posterior_precision: &Array2<f64>,
210 posterior_alpha: f64,
211 posterior_beta: f64,
212 ) -> StatsResult<f64> {
213 let n = x.nrows() as f64;
214 let _p = x.ncols() as f64;
215
216 let prior_log_det =
218 scirs2_linalg::det(&self.prior_precision.view(), None).map_err(|e| {
219 StatsError::ComputationError(format!("Failed to compute prior determinant: {}", e))
220 })?;
221
222 let posterior_log_det =
223 scirs2_linalg::det(&posterior_precision.view(), None).map_err(|e| {
224 StatsError::ComputationError(format!(
225 "Failed to compute posterior determinant: {}",
226 e
227 ))
228 })?;
229
230 if prior_log_det <= 0.0 || posterior_log_det <= 0.0 {
231 return Err(StatsError::ComputationError(
232 "Precision matrices must be positive definite".to_string(),
233 ));
234 }
235
236 let gamma_ratio = gamma_log(posterior_alpha) - gamma_log(self.prior_alpha);
238
239 let log_ml = -0.5 * n * (2.0 * std::f64::consts::PI).ln() + 0.5 * prior_log_det.ln()
241 - 0.5 * posterior_log_det.ln()
242 + self.prior_alpha * self.prior_beta.ln()
243 - posterior_alpha * posterior_beta.ln()
244 + gamma_ratio;
245
246 Ok(log_ml)
247 }
248
249 pub fn predict(
251 &self,
252 x: ArrayView2<f64>,
253 result: &BayesianRegressionResult,
254 ) -> StatsResult<BayesianPredictionResult> {
255 checkarray_finite(&x, "x")?;
256 let (n_test, n_features) = x.dim();
257
258 if n_features != result.n_features {
259 return Err(StatsError::DimensionMismatch(format!(
260 "x has {} features, expected {}",
261 n_features, result.n_features
262 )));
263 }
264
265 let x_centered = if let Some(ref x_mean) = result.x_mean {
267 let mut x_c = x.to_owned();
268 for mut row in x_c.rows_mut() {
269 row -= x_mean;
270 }
271 x_c
272 } else {
273 x.to_owned()
274 };
275
276 let y_pred_centered = x_centered.dot(&result.posterior_mean);
278 let y_pred = if let Some(y_mean) = result.y_mean {
279 &y_pred_centered + y_mean
280 } else {
281 y_pred_centered.clone()
282 };
283
284 let noise_variance = result.posterior_beta / (result.posterior_alpha - 1.0);
286 let mut predictive_variance = Array1::zeros(n_test);
287
288 for i in 0..n_test {
289 let x_row = x_centered.row(i);
290 let model_variance = x_row.dot(&result.posterior_covariance.dot(&x_row));
291 predictive_variance[i] = noise_variance * (1.0 + model_variance);
292 }
293
294 let df = 2.0 * result.posterior_alpha;
296
297 Ok(BayesianPredictionResult {
298 mean: y_pred,
299 variance: predictive_variance,
300 degrees_of_freedom: df,
301 credible_interval: None,
302 })
303 }
304
305 pub fn predict_with_credible_interval(
307 &self,
308 x: ArrayView2<f64>,
309 result: &BayesianRegressionResult,
310 confidence: f64,
311 ) -> StatsResult<BayesianPredictionResult> {
312 check_probability(confidence, "confidence")?;
313
314 let mut pred_result = self.predict(x, result)?;
315
316 let alpha = (1.0 - confidence) / 2.0;
318 let df = pred_result.degrees_of_freedom;
319
320 let t_critical = if df > 30.0 {
322 normal_ppf(1.0 - alpha)?
324 } else {
325 t_ppf(1.0 - alpha, df)?
327 };
328
329 let mut lower_bounds = Array1::zeros(pred_result.mean.len());
330 let mut upper_bounds = Array1::zeros(pred_result.mean.len());
331
332 for i in 0..pred_result.mean.len() {
333 let std_err = pred_result.variance[i].sqrt();
334 lower_bounds[i] = pred_result.mean[i] - t_critical * std_err;
335 upper_bounds[i] = pred_result.mean[i] + t_critical * std_err;
336 }
337
338 pred_result.credible_interval = Some((lower_bounds, upper_bounds));
339 Ok(pred_result)
340 }
341}
342
/// Predictive distribution returned by [`BayesianLinearRegression::predict`].
#[derive(Debug, Clone)]
pub struct BayesianPredictionResult {
    /// Predictive mean for each test point.
    pub mean: Array1<f64>,
    /// Predictive variance for each test point (noise + parameter
    /// uncertainty).
    pub variance: Array1<f64>,
    /// Degrees of freedom of the Student-t predictive distribution.
    pub degrees_of_freedom: f64,
    /// `(lower, upper)` bounds; populated only by
    /// `predict_with_credible_interval`.
    pub credible_interval: Option<(Array1<f64>, Array1<f64>)>,
}
355
/// Automatic relevance determination (ARD) Bayesian regression, which learns
/// one precision per feature so irrelevant features are pruned automatically.
#[derive(Debug, Clone)]
pub struct ARDBayesianRegression {
    /// Maximum number of fitting iterations before `fit` reports failure.
    pub max_iter: usize,
    /// Convergence tolerance on the change in log marginal likelihood.
    pub tol: f64,
    /// Optional initial per-feature precisions; defaults to all ones.
    pub alpha_init: Option<Array1<f64>>,
    /// Initial noise precision.
    pub beta_init: f64,
    /// If true, `fit` centers x and y so an intercept is modeled implicitly.
    pub fit_intercept: bool,
}
373
impl Default for ARDBayesianRegression {
    /// Equivalent to [`ARDBayesianRegression::new`].
    fn default() -> Self {
        Self::new()
    }
}
379
380impl ARDBayesianRegression {
381 pub fn new() -> Self {
383 Self {
384 max_iter: 300,
385 tol: 1e-3,
386 alpha_init: None,
387 beta_init: 1.0,
388 fit_intercept: true,
389 }
390 }
391
    /// Builder-style setter for the maximum number of fitting iterations.
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }
397
    /// Builder-style setter for the convergence tolerance applied to the
    /// change in log marginal likelihood between iterations.
    pub fn with_tolerance(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }
403
404 pub fn fit(&self, x: ArrayView2<f64>, y: ArrayView1<f64>) -> StatsResult<ARDRegressionResult> {
406 checkarray_finite(&x, "x")?;
407 checkarray_finite(&y, "y")?;
408 let (n_samples_, n_features) = x.dim();
409
410 if y.len() != n_samples_ {
411 return Err(StatsError::DimensionMismatch(format!(
412 "y length ({}) must match x rows ({})",
413 y.len(),
414 n_samples_
415 )));
416 }
417
418 let (x_centered, y_centered, x_mean, y_mean) = if self.fit_intercept {
420 let x_mean = x.mean_axis(Axis(0)).expect("Operation failed");
421 let y_mean = y.mean();
422
423 let mut x_centered = x.to_owned();
424 for mut row in x_centered.rows_mut() {
425 row -= &x_mean;
426 }
427
428 let y_centered = &y.to_owned() - y_mean;
429
430 (x_centered, y_centered, Some(x_mean), Some(y_mean))
431 } else {
432 (x.to_owned(), y.to_owned(), None, None)
433 };
434
435 let mut alpha = self
437 .alpha_init
438 .clone()
439 .unwrap_or_else(|| Array1::from_elem(n_features, 1.0));
440 let mut beta = self.beta_init;
441
442 let xtx = x_centered.t().dot(&x_centered);
443 let xty = x_centered.t().dot(&y_centered);
444
445 let mut prev_log_ml = f64::NEG_INFINITY;
446
447 for iteration in 0..self.max_iter {
448 let alpha_diag = Array2::from_diag(&alpha);
450 let precision = &alpha_diag + beta * &xtx;
451
452 let covariance = scirs2_linalg::inv(&precision.view(), None).map_err(|e| {
453 StatsError::ComputationError(format!("Failed to invert precision: {}", e))
454 })?;
455
456 let mean = beta * covariance.dot(&xty);
457
458 let mut new_alpha = Array1::zeros(n_features);
460 for i in 0..n_features {
461 let gamma_i = 1.0 - alpha[i] * covariance[[i, i]];
462 new_alpha[i] = gamma_i / (mean[i] * mean[i]);
463
464 if !new_alpha[i].is_finite() || new_alpha[i] < 1e-12 {
466 new_alpha[i] = 1e-12;
467 }
468 }
469
470 let y_pred = x_centered.dot(&mean);
472 let residuals = &y_centered - &y_pred;
473 let rss = residuals.dot(&residuals);
474
475 let _trace_cov = covariance.diag().sum();
476 let new_beta =
477 (n_samples_ as f64 - new_alpha.sum() + alpha.dot(&covariance.diag())) / rss;
478
479 let log_ml = self.compute_ard_log_marginal_likelihood(
481 &x_centered,
482 &y_centered,
483 &new_alpha,
484 new_beta,
485 )?;
486
487 if (log_ml - prev_log_ml).abs() < self.tol {
488 alpha = new_alpha;
489 beta = new_beta;
490 break;
491 }
492
493 alpha = new_alpha;
494 beta = new_beta;
495 prev_log_ml = log_ml;
496
497 if iteration == self.max_iter - 1 {
498 return Err(StatsError::ComputationError(format!(
499 "ARD failed to converge after {} iterations",
500 self.max_iter
501 )));
502 }
503 }
504
505 let alpha_diag = Array2::from_diag(&alpha);
507 let precision = &alpha_diag + beta * &xtx;
508 let covariance = scirs2_linalg::inv(&precision.view(), None).map_err(|e| {
509 StatsError::ComputationError(format!("Failed to compute final covariance: {}", e))
510 })?;
511 let mean = beta * covariance.dot(&xty);
512
513 Ok(ARDRegressionResult {
514 posterior_mean: mean,
515 posterior_covariance: covariance,
516 alpha,
517 beta,
518 n_samples_,
519 n_features,
520 x_mean,
521 y_mean,
522 log_marginal_likelihood: prev_log_ml,
523 })
524 }
525
526 fn compute_ard_log_marginal_likelihood(
528 &self,
529 x: &Array2<f64>,
530 y: &Array1<f64>,
531 alpha: &Array1<f64>,
532 beta: f64,
533 ) -> StatsResult<f64> {
534 let n = x.nrows() as f64;
535 let p = x.ncols() as f64;
536
537 let xtx = x.t().dot(x);
538 let xty = x.t().dot(y);
539
540 let alpha_diag = Array2::from_diag(alpha);
541 let precision = &alpha_diag + beta * &xtx;
542
543 let covariance = scirs2_linalg::inv(&precision.view(), None).map_err(|e| {
544 StatsError::ComputationError(format!("Failed to invert precision for log ML: {}", e))
545 })?;
546
547 let mean = beta * covariance.dot(&xty);
548
549 let log_det_precision = scirs2_linalg::det(&precision.view(), None).map_err(|e| {
551 StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
552 })?;
553
554 if log_det_precision <= 0.0 {
555 return Err(StatsError::ComputationError(
556 "Precision matrix must be positive definite".to_string(),
557 ));
558 }
559
560 let y_pred = x.dot(&mean);
562 let residuals = y - &y_pred;
563 let data_fit = beta * residuals.dot(&residuals);
564 let penalty = alpha
565 .iter()
566 .zip(mean.iter())
567 .map(|(&a, &m)| a * m * m)
568 .sum::<f64>();
569
570 let log_ml = 0.5
571 * (p * alpha.mapv(f64::ln).sum() + n * beta.ln() + log_det_precision.ln()
572 - n * (2.0 * std::f64::consts::PI).ln()
573 - data_fit
574 - penalty);
575
576 Ok(log_ml)
577 }
578}
579
/// Result of fitting [`ARDBayesianRegression`].
#[derive(Debug, Clone)]
pub struct ARDRegressionResult {
    /// Posterior mean of the weights.
    pub posterior_mean: Array1<f64>,
    /// Posterior covariance of the weights.
    pub posterior_covariance: Array2<f64>,
    /// Learned per-feature precisions; large values indicate pruned
    /// (irrelevant) features.
    pub alpha: Array1<f64>,
    /// Learned noise precision.
    pub beta: f64,
    /// Number of training samples used in the fit.
    pub n_samples_: usize,
    /// Number of features the model was fitted on.
    pub n_features: usize,
    /// Column means of x when `fit_intercept` was true; `None` otherwise.
    pub x_mean: Option<Array1<f64>>,
    /// Mean of y when `fit_intercept` was true; `None` otherwise.
    pub y_mean: Option<f64>,
    /// Log marginal likelihood at the final hyperparameters.
    pub log_marginal_likelihood: f64,
}
602
/// Natural log of the gamma function, ln Γ(x), for x > 0.
///
/// Uses the recurrence ln Γ(x) = ln Γ(x + 1) - ln x to shift the argument
/// up to x >= 8, where the Stirling series with three correction terms is
/// accurate to roughly 1e-9. (The previous version applied Stirling
/// directly for all x >= 1, with errors up to about 2e-3 near x = 1.)
///
/// Returns negative infinity for x <= 0, where Γ is undefined or has poles.
#[allow(dead_code)]
fn gamma_log(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::NEG_INFINITY;
    }

    // Shift upward: ln Γ(x) = ln Γ(x + k) - sum_{j<k} ln(x + j).
    let mut z = x;
    let mut shift = 0.0;
    while z < 8.0 {
        shift += z.ln();
        z += 1.0;
    }

    // Stirling series: ln Γ(z) ≈ (z - 1/2) ln z - z + (1/2) ln 2π
    //                            + 1/(12 z) - 1/(360 z^3) + 1/(1260 z^5).
    let inv = 1.0 / z;
    let inv2 = inv * inv;
    let series = inv * (1.0 / 12.0 - inv2 * (1.0 / 360.0 - inv2 / 1260.0));

    0.5 * (2.0 * std::f64::consts::PI).ln() + (z - 0.5) * z.ln() - z + series - shift
}
620
621#[allow(dead_code)]
623fn normal_ppf(p: f64) -> StatsResult<f64> {
624 if p <= 0.0 || p >= 1.0 {
625 return Err(StatsError::InvalidArgument(
626 "p must be between 0 and 1".to_string(),
627 ));
628 }
629
630 let q = p - 0.5;
633 let result = if q.abs() < 0.5 {
634 let r = q * q;
635 let num =
636 (((-25.44106049637) * r + 41.39119773534) * r + (-18.61500062529)) * r + 2.50662823884;
637 let den = (((-7.784894002430) * r + 14.38718147627) * r + (-3.47396220392)) * r + 1.0;
638 q * num / den
639 } else {
640 let r = if q < 0.0 { p } else { 1.0 - p };
641 let num = (2.01033439929 * r.ln() + 4.8232411251) * r.ln() + 6.6;
642 let result = (num.exp() - 1.0).sqrt();
643 if q < 0.0 {
644 -result
645 } else {
646 result
647 }
648 };
649
650 Ok(result)
651}
652
653#[allow(dead_code)]
655fn t_ppf(p: f64, df: f64) -> StatsResult<f64> {
656 if p <= 0.0 || p >= 1.0 {
657 return Err(StatsError::InvalidArgument(
658 "p must be between 0 and 1".to_string(),
659 ));
660 }
661
662 let z = normal_ppf(p)?;
664
665 if df > 4.0 {
666 let correction = z * z * z / (4.0 * df) + z * z * z * z * z / (96.0 * df * df);
667 Ok(z + correction)
668 } else {
669 Ok(z * (1.0 + (z * z + 1.0) / (4.0 * df)))
671 }
672}