// scirs2_stats/bayesian/regression.rs

1//! Bayesian linear regression models
2//!
3//! This module implements Bayesian approaches to linear regression, providing
4//! posterior distributions over model parameters and predictions.
5
6use crate::error::{StatsError, StatsResult};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
8use scirs2_core::validation::*;
9use scirs2_linalg;
10use statrs::statistics::Statistics;
11
/// Bayesian linear regression with normal-inverse-gamma prior
///
/// This implements Bayesian linear regression where:
/// - Coefficients have a normal prior
/// - Noise variance has an inverse-gamma prior
/// - Posterior is analytically tractable (fully conjugate model)
#[derive(Debug, Clone)]
pub struct BayesianLinearRegression {
    /// Prior mean for coefficients (length = number of features)
    pub prior_mean: Array1<f64>,
    /// Prior precision (inverse covariance) matrix for coefficients
    pub prior_precision: Array2<f64>,
    /// Prior shape parameter of the inverse-gamma prior on the noise variance
    pub prior_alpha: f64,
    /// Prior scale parameter of the inverse-gamma prior on the noise variance
    pub prior_beta: f64,
    /// Whether to include intercept (data is mean-centered during fitting when true)
    pub fit_intercept: bool,
}
31
/// Result of Bayesian linear regression fit
///
/// Holds the parameters of the normal-inverse-gamma posterior together with
/// the centering statistics needed to make predictions on new data.
#[derive(Debug, Clone)]
pub struct BayesianRegressionResult {
    /// Posterior mean of coefficients
    pub posterior_mean: Array1<f64>,
    /// Posterior covariance of coefficients
    pub posterior_covariance: Array2<f64>,
    /// Posterior shape parameter of the inverse-gamma noise-variance posterior
    pub posterior_alpha: f64,
    /// Posterior scale parameter of the inverse-gamma noise-variance posterior
    pub posterior_beta: f64,
    /// Number of training samples
    pub n_samples_: usize,
    /// Number of features
    pub n_features: usize,
    /// Training data column means (`Some` only when an intercept was fit)
    pub x_mean: Option<Array1<f64>>,
    /// Training target mean (`Some` only when an intercept was fit)
    pub y_mean: Option<f64>,
    /// Log marginal likelihood (model evidence) of the fitted model
    pub log_marginal_likelihood: f64,
}
54
55impl BayesianLinearRegression {
56    /// Create a new Bayesian linear regression model
57    pub fn new(n_features: usize, fit_intercept: bool) -> StatsResult<Self> {
58        check_positive(n_features, "n_features")?;
59
60        // Default to weakly informative priors
61        let prior_mean = Array1::zeros(n_features);
62        let prior_precision = Array2::eye(n_features) * 1e-6; // Very small precision (large variance)
63        let prior_alpha = 1e-6; // Very small shape
64        let prior_beta = 1e-6; // Very small scale
65
66        Ok(Self {
67            prior_mean,
68            prior_precision,
69            prior_alpha,
70            prior_beta,
71            fit_intercept,
72        })
73    }
74
75    /// Create with custom priors
76    pub fn with_priors(
77        prior_mean: Array1<f64>,
78        prior_precision: Array2<f64>,
79        prior_alpha: f64,
80        prior_beta: f64,
81        fit_intercept: bool,
82    ) -> StatsResult<Self> {
83        checkarray_finite(&prior_mean, "prior_mean")?;
84        checkarray_finite(&prior_precision, "prior_precision")?;
85        check_positive(prior_alpha, "prior_alpha")?;
86        check_positive(prior_beta, "prior_beta")?;
87
88        if prior_precision.nrows() != prior_mean.len()
89            || prior_precision.ncols() != prior_mean.len()
90        {
91            return Err(StatsError::DimensionMismatch(format!(
92                "prior_precision shape ({}, {}) must match prior_mean length ({})",
93                prior_precision.nrows(),
94                prior_precision.ncols(),
95                prior_mean.len()
96            )));
97        }
98
99        Ok(Self {
100            prior_mean,
101            prior_precision,
102            prior_alpha,
103            prior_beta,
104            fit_intercept,
105        })
106    }
107
108    /// Fit the Bayesian regression model
109    pub fn fit(
110        &self,
111        x: ArrayView2<f64>,
112        y: ArrayView1<f64>,
113    ) -> StatsResult<BayesianRegressionResult> {
114        checkarray_finite(&x, "x")?;
115        checkarray_finite(&y, "y")?;
116        let (n_samples_, n_features) = x.dim();
117
118        if y.len() != n_samples_ {
119            return Err(StatsError::DimensionMismatch(format!(
120                "y length ({}) must match x rows ({})",
121                y.len(),
122                n_samples_
123            )));
124        }
125
126        if n_samples_ < 2 {
127            return Err(StatsError::InvalidArgument(
128                "n_samples_ must be at least 2".to_string(),
129            ));
130        }
131
132        // Center data if fitting intercept
133        let (x_centered, y_centered, x_mean, y_mean) = if self.fit_intercept {
134            let x_mean = x.mean_axis(Axis(0)).expect("Operation failed");
135            let y_mean = y.mean();
136
137            let mut x_centered = x.to_owned();
138            for mut row in x_centered.rows_mut() {
139                row -= &x_mean;
140            }
141
142            let y_centered = &y.to_owned() - y_mean;
143
144            (x_centered, y_centered, Some(x_mean), Some(y_mean))
145        } else {
146            (x.to_owned(), y.to_owned(), None, None)
147        };
148
149        // Compute posterior parameters
150        let xtx = x_centered.t().dot(&x_centered);
151        let xty = x_centered.t().dot(&y_centered);
152
153        // Posterior precision
154        let posterior_precision = &self.prior_precision + &xtx;
155        let posterior_covariance =
156            scirs2_linalg::inv(&posterior_precision.view(), None).map_err(|e| {
157                StatsError::ComputationError(format!("Failed to invert posterior precision: {}", e))
158            })?;
159
160        // Posterior mean
161        let prior_contribution = self.prior_precision.dot(&self.prior_mean);
162        let data_contribution = &xty;
163        let posterior_mean = posterior_covariance.dot(&(&prior_contribution + data_contribution));
164
165        // Posterior shape and scale for noise variance
166        let posterior_alpha = self.prior_alpha + n_samples_ as f64 / 2.0;
167
168        // Compute residual sum of squares
169        let y_pred = x_centered.dot(&posterior_mean);
170        let residuals = &y_centered - &y_pred;
171        let rss = residuals.dot(&residuals);
172
173        // Prior contribution to scale
174        let prior_quad_form = (&self.prior_mean - &posterior_mean).t().dot(
175            &self
176                .prior_precision
177                .dot(&(&self.prior_mean - &posterior_mean)),
178        );
179
180        let posterior_beta = self.prior_beta + 0.5 * (rss + prior_quad_form);
181
182        // Compute log marginal likelihood
183        let log_marginal = self.compute_log_marginal_likelihood(
184            &x_centered,
185            &y_centered,
186            &posterior_precision,
187            posterior_alpha,
188            posterior_beta,
189        )?;
190
191        Ok(BayesianRegressionResult {
192            posterior_mean,
193            posterior_covariance,
194            posterior_alpha,
195            posterior_beta,
196            n_samples_,
197            n_features,
198            x_mean,
199            y_mean,
200            log_marginal_likelihood: log_marginal,
201        })
202    }
203
204    /// Compute log marginal likelihood
205    fn compute_log_marginal_likelihood(
206        &self,
207        x: &Array2<f64>,
208        _y: &Array1<f64>,
209        posterior_precision: &Array2<f64>,
210        posterior_alpha: f64,
211        posterior_beta: f64,
212    ) -> StatsResult<f64> {
213        let n = x.nrows() as f64;
214        let _p = x.ncols() as f64;
215
216        // Log determinant terms
217        let prior_log_det =
218            scirs2_linalg::det(&self.prior_precision.view(), None).map_err(|e| {
219                StatsError::ComputationError(format!("Failed to compute prior determinant: {}", e))
220            })?;
221
222        let posterior_log_det =
223            scirs2_linalg::det(&posterior_precision.view(), None).map_err(|e| {
224                StatsError::ComputationError(format!(
225                    "Failed to compute posterior determinant: {}",
226                    e
227                ))
228            })?;
229
230        if prior_log_det <= 0.0 || posterior_log_det <= 0.0 {
231            return Err(StatsError::ComputationError(
232                "Precision matrices must be positive definite".to_string(),
233            ));
234        }
235
236        // Gamma function terms
237        let gamma_ratio = gamma_log(posterior_alpha) - gamma_log(self.prior_alpha);
238
239        // Assemble log marginal likelihood
240        let log_ml = -0.5 * n * (2.0 * std::f64::consts::PI).ln() + 0.5 * prior_log_det.ln()
241            - 0.5 * posterior_log_det.ln()
242            + self.prior_alpha * self.prior_beta.ln()
243            - posterior_alpha * posterior_beta.ln()
244            + gamma_ratio;
245
246        Ok(log_ml)
247    }
248
249    /// Make predictions on new data
250    pub fn predict(
251        &self,
252        x: ArrayView2<f64>,
253        result: &BayesianRegressionResult,
254    ) -> StatsResult<BayesianPredictionResult> {
255        checkarray_finite(&x, "x")?;
256        let (n_test, n_features) = x.dim();
257
258        if n_features != result.n_features {
259            return Err(StatsError::DimensionMismatch(format!(
260                "x has {} features, expected {}",
261                n_features, result.n_features
262            )));
263        }
264
265        // Center test data if model was fit with intercept
266        let x_centered = if let Some(ref x_mean) = result.x_mean {
267            let mut x_c = x.to_owned();
268            for mut row in x_c.rows_mut() {
269                row -= x_mean;
270            }
271            x_c
272        } else {
273            x.to_owned()
274        };
275
276        // Predictive mean
277        let y_pred_centered = x_centered.dot(&result.posterior_mean);
278        let y_pred = if let Some(y_mean) = result.y_mean {
279            &y_pred_centered + y_mean
280        } else {
281            y_pred_centered.clone()
282        };
283
284        // Predictive variance
285        let noise_variance = result.posterior_beta / (result.posterior_alpha - 1.0);
286        let mut predictive_variance = Array1::zeros(n_test);
287
288        for i in 0..n_test {
289            let x_row = x_centered.row(i);
290            let model_variance = x_row.dot(&result.posterior_covariance.dot(&x_row));
291            predictive_variance[i] = noise_variance * (1.0 + model_variance);
292        }
293
294        // Degrees of freedom for t-distribution
295        let df = 2.0 * result.posterior_alpha;
296
297        Ok(BayesianPredictionResult {
298            mean: y_pred,
299            variance: predictive_variance,
300            degrees_of_freedom: df,
301            credible_interval: None,
302        })
303    }
304
305    /// Compute credible intervals for predictions
306    pub fn predict_with_credible_interval(
307        &self,
308        x: ArrayView2<f64>,
309        result: &BayesianRegressionResult,
310        confidence: f64,
311    ) -> StatsResult<BayesianPredictionResult> {
312        check_probability(confidence, "confidence")?;
313
314        let mut pred_result = self.predict(x, result)?;
315
316        // Compute credible intervals using t-distribution
317        let alpha = (1.0 - confidence) / 2.0;
318        let df = pred_result.degrees_of_freedom;
319
320        // For simplicity, use normal approximation when df is large
321        let t_critical = if df > 30.0 {
322            // Use normal approximation
323            normal_ppf(1.0 - alpha)?
324        } else {
325            // Use t-distribution (simplified)
326            t_ppf(1.0 - alpha, df)?
327        };
328
329        let mut lower_bounds = Array1::zeros(pred_result.mean.len());
330        let mut upper_bounds = Array1::zeros(pred_result.mean.len());
331
332        for i in 0..pred_result.mean.len() {
333            let std_err = pred_result.variance[i].sqrt();
334            lower_bounds[i] = pred_result.mean[i] - t_critical * std_err;
335            upper_bounds[i] = pred_result.mean[i] + t_critical * std_err;
336        }
337
338        pred_result.credible_interval = Some((lower_bounds, upper_bounds));
339        Ok(pred_result)
340    }
341}
342
/// Result of Bayesian prediction
#[derive(Debug, Clone)]
pub struct BayesianPredictionResult {
    /// Predictive mean, one entry per test sample
    pub mean: Array1<f64>,
    /// Predictive variance (noise plus model uncertainty), one entry per test sample
    pub variance: Array1<f64>,
    /// Degrees of freedom of the Student-t predictive distribution
    pub degrees_of_freedom: f64,
    /// Credible interval (lower, upper) bounds if computed
    pub credible_interval: Option<(Array1<f64>, Array1<f64>)>,
}
355
/// Automatic Relevance Determination (ARD) Bayesian regression
///
/// ARD uses separate precision parameters for each feature to perform
/// automatic feature selection by driving irrelevant features to zero.
#[derive(Debug, Clone)]
pub struct ARDBayesianRegression {
    /// Maximum number of evidence-maximization iterations
    pub max_iter: usize,
    /// Convergence tolerance on the change in log marginal likelihood
    pub tol: f64,
    /// Initial alpha (per-feature precision) parameters; all ones when `None`
    pub alpha_init: Option<Array1<f64>>,
    /// Initial beta (noise precision) parameter
    pub beta_init: f64,
    /// Whether to fit intercept (data is mean-centered during fitting when true)
    pub fit_intercept: bool,
}
373
impl Default for ARDBayesianRegression {
    /// Equivalent to [`ARDBayesianRegression::new`].
    fn default() -> Self {
        Self::new()
    }
}
379
380impl ARDBayesianRegression {
381    /// Create a new ARD Bayesian regression model
382    pub fn new() -> Self {
383        Self {
384            max_iter: 300,
385            tol: 1e-3,
386            alpha_init: None,
387            beta_init: 1.0,
388            fit_intercept: true,
389        }
390    }
391
392    /// Set maximum iterations
393    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
394        self.max_iter = max_iter;
395        self
396    }
397
398    /// Set convergence tolerance
399    pub fn with_tolerance(mut self, tol: f64) -> Self {
400        self.tol = tol;
401        self
402    }
403
404    /// Fit ARD Bayesian regression using iterative optimization
405    pub fn fit(&self, x: ArrayView2<f64>, y: ArrayView1<f64>) -> StatsResult<ARDRegressionResult> {
406        checkarray_finite(&x, "x")?;
407        checkarray_finite(&y, "y")?;
408        let (n_samples_, n_features) = x.dim();
409
410        if y.len() != n_samples_ {
411            return Err(StatsError::DimensionMismatch(format!(
412                "y length ({}) must match x rows ({})",
413                y.len(),
414                n_samples_
415            )));
416        }
417
418        // Center data if fitting intercept
419        let (x_centered, y_centered, x_mean, y_mean) = if self.fit_intercept {
420            let x_mean = x.mean_axis(Axis(0)).expect("Operation failed");
421            let y_mean = y.mean();
422
423            let mut x_centered = x.to_owned();
424            for mut row in x_centered.rows_mut() {
425                row -= &x_mean;
426            }
427
428            let y_centered = &y.to_owned() - y_mean;
429
430            (x_centered, y_centered, Some(x_mean), Some(y_mean))
431        } else {
432            (x.to_owned(), y.to_owned(), None, None)
433        };
434
435        // Initialize hyperparameters
436        let mut alpha = self
437            .alpha_init
438            .clone()
439            .unwrap_or_else(|| Array1::from_elem(n_features, 1.0));
440        let mut beta = self.beta_init;
441
442        let xtx = x_centered.t().dot(&x_centered);
443        let xty = x_centered.t().dot(&y_centered);
444
445        let mut prev_log_ml = f64::NEG_INFINITY;
446
447        for iteration in 0..self.max_iter {
448            // Update posterior mean and covariance
449            let alpha_diag = Array2::from_diag(&alpha);
450            let precision = &alpha_diag + beta * &xtx;
451
452            let covariance = scirs2_linalg::inv(&precision.view(), None).map_err(|e| {
453                StatsError::ComputationError(format!("Failed to invert precision: {}", e))
454            })?;
455
456            let mean = beta * covariance.dot(&xty);
457
458            // Update alpha (feature precisions)
459            let mut new_alpha = Array1::zeros(n_features);
460            for i in 0..n_features {
461                let gamma_i = 1.0 - alpha[i] * covariance[[i, i]];
462                new_alpha[i] = gamma_i / (mean[i] * mean[i]);
463
464                // Prevent numerical issues
465                if !new_alpha[i].is_finite() || new_alpha[i] < 1e-12 {
466                    new_alpha[i] = 1e-12;
467                }
468            }
469
470            // Update beta (noise precision)
471            let y_pred = x_centered.dot(&mean);
472            let residuals = &y_centered - &y_pred;
473            let rss = residuals.dot(&residuals);
474
475            let _trace_cov = covariance.diag().sum();
476            let new_beta =
477                (n_samples_ as f64 - new_alpha.sum() + alpha.dot(&covariance.diag())) / rss;
478
479            // Check convergence
480            let log_ml = self.compute_ard_log_marginal_likelihood(
481                &x_centered,
482                &y_centered,
483                &new_alpha,
484                new_beta,
485            )?;
486
487            if (log_ml - prev_log_ml).abs() < self.tol {
488                alpha = new_alpha;
489                beta = new_beta;
490                break;
491            }
492
493            alpha = new_alpha;
494            beta = new_beta;
495            prev_log_ml = log_ml;
496
497            if iteration == self.max_iter - 1 {
498                return Err(StatsError::ComputationError(format!(
499                    "ARD failed to converge after {} iterations",
500                    self.max_iter
501                )));
502            }
503        }
504
505        // Final posterior computation
506        let alpha_diag = Array2::from_diag(&alpha);
507        let precision = &alpha_diag + beta * &xtx;
508        let covariance = scirs2_linalg::inv(&precision.view(), None).map_err(|e| {
509            StatsError::ComputationError(format!("Failed to compute final covariance: {}", e))
510        })?;
511        let mean = beta * covariance.dot(&xty);
512
513        Ok(ARDRegressionResult {
514            posterior_mean: mean,
515            posterior_covariance: covariance,
516            alpha,
517            beta,
518            n_samples_,
519            n_features,
520            x_mean,
521            y_mean,
522            log_marginal_likelihood: prev_log_ml,
523        })
524    }
525
526    /// Compute log marginal likelihood for ARD
527    fn compute_ard_log_marginal_likelihood(
528        &self,
529        x: &Array2<f64>,
530        y: &Array1<f64>,
531        alpha: &Array1<f64>,
532        beta: f64,
533    ) -> StatsResult<f64> {
534        let n = x.nrows() as f64;
535        let p = x.ncols() as f64;
536
537        let xtx = x.t().dot(x);
538        let xty = x.t().dot(y);
539
540        let alpha_diag = Array2::from_diag(alpha);
541        let precision = &alpha_diag + beta * &xtx;
542
543        let covariance = scirs2_linalg::inv(&precision.view(), None).map_err(|e| {
544            StatsError::ComputationError(format!("Failed to invert precision for log ML: {}", e))
545        })?;
546
547        let mean = beta * covariance.dot(&xty);
548
549        // Compute log determinant
550        let log_det_precision = scirs2_linalg::det(&precision.view(), None).map_err(|e| {
551            StatsError::ComputationError(format!("Failed to compute determinant: {}", e))
552        })?;
553
554        if log_det_precision <= 0.0 {
555            return Err(StatsError::ComputationError(
556                "Precision matrix must be positive definite".to_string(),
557            ));
558        }
559
560        // Compute quadratic forms
561        let y_pred = x.dot(&mean);
562        let residuals = y - &y_pred;
563        let data_fit = beta * residuals.dot(&residuals);
564        let penalty = alpha
565            .iter()
566            .zip(mean.iter())
567            .map(|(&a, &m)| a * m * m)
568            .sum::<f64>();
569
570        let log_ml = 0.5
571            * (p * alpha.mapv(f64::ln).sum() + n * beta.ln() + log_det_precision.ln()
572                - n * (2.0 * std::f64::consts::PI).ln()
573                - data_fit
574                - penalty);
575
576        Ok(log_ml)
577    }
578}
579
/// Result of ARD Bayesian regression fit
#[derive(Debug, Clone)]
pub struct ARDRegressionResult {
    /// Posterior mean of coefficients
    pub posterior_mean: Array1<f64>,
    /// Posterior covariance of coefficients
    pub posterior_covariance: Array2<f64>,
    /// Feature precision parameters (higher = less relevant feature)
    pub alpha: Array1<f64>,
    /// Noise precision parameter
    pub beta: f64,
    /// Number of training samples
    pub n_samples_: usize,
    /// Number of features
    pub n_features: usize,
    /// Training data column means (`Some` only when an intercept was fit)
    pub x_mean: Option<Array1<f64>>,
    /// Training target mean (`Some` only when an intercept was fit)
    pub y_mean: Option<f64>,
    /// Log marginal likelihood at the converged hyperparameters
    pub log_marginal_likelihood: f64,
}
602
603// Helper functions for statistical distributions
604
/// Natural log of the gamma function, ln Γ(x), for x > 0.
///
/// Uses the Lanczos approximation (g = 7, 9 coefficients), accurate to near
/// machine precision over the positive reals. This replaces the previous
/// Stirling-series version, which was only ~1e-3 accurate near x = 1
/// (it returned ≈0.0023 for ln Γ(1) instead of 0).
///
/// Returns `f64::NEG_INFINITY` for `x <= 0.0`, matching the previous
/// contract (Γ has poles at the non-positive integers).
#[allow(dead_code)]
fn gamma_log(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::NEG_INFINITY;
    }

    // Lanczos coefficients for g = 7.
    const LANCZOS_G: f64 = 7.0;
    const COEFFS: [f64; 9] = [
        0.99999999999980993,
        676.5203681218851,
        -1259.1392167224028,
        771.32342877765313,
        -176.61502916214059,
        12.507343278686905,
        -0.13857109526572012,
        9.9843695780195716e-6,
        1.5056327351493116e-7,
    ];

    if x < 0.5 {
        // Reflection formula: Γ(x) Γ(1 − x) = π / sin(πx).
        let pi = std::f64::consts::PI;
        (pi / (pi * x).sin()).ln() - gamma_log(1.0 - x)
    } else {
        // Lanczos series evaluated at z = x − 1.
        let z = x - 1.0;
        let mut series = COEFFS[0];
        for (i, &c) in COEFFS.iter().enumerate().skip(1) {
            series += c / (z + i as f64);
        }
        let t = z + LANCZOS_G + 0.5;
        0.5 * (2.0 * std::f64::consts::PI).ln() + (z + 0.5) * t.ln() - t + series.ln()
    }
}
620
621/// Normal distribution percent point function (inverse CDF)
622#[allow(dead_code)]
623fn normal_ppf(p: f64) -> StatsResult<f64> {
624    if p <= 0.0 || p >= 1.0 {
625        return Err(StatsError::InvalidArgument(
626            "p must be between 0 and 1".to_string(),
627        ));
628    }
629
630    // Using Box-Muller inspired approximation
631    // In practice, you'd use a more accurate inverse error function
632    let q = p - 0.5;
633    let result = if q.abs() < 0.5 {
634        let r = q * q;
635        let num =
636            (((-25.44106049637) * r + 41.39119773534) * r + (-18.61500062529)) * r + 2.50662823884;
637        let den = (((-7.784894002430) * r + 14.38718147627) * r + (-3.47396220392)) * r + 1.0;
638        q * num / den
639    } else {
640        let r = if q < 0.0 { p } else { 1.0 - p };
641        let num = (2.01033439929 * r.ln() + 4.8232411251) * r.ln() + 6.6;
642        let result = (num.exp() - 1.0).sqrt();
643        if q < 0.0 {
644            -result
645        } else {
646            result
647        }
648    };
649
650    Ok(result)
651}
652
653/// Student's t distribution percent point function (simplified)
654#[allow(dead_code)]
655fn t_ppf(p: f64, df: f64) -> StatsResult<f64> {
656    if p <= 0.0 || p >= 1.0 {
657        return Err(StatsError::InvalidArgument(
658            "p must be between 0 and 1".to_string(),
659        ));
660    }
661
662    // Simplified approximation - in practice use proper t-distribution
663    let z = normal_ppf(p)?;
664
665    if df > 4.0 {
666        let correction = z * z * z / (4.0 * df) + z * z * z * z * z / (96.0 * df * df);
667        Ok(z + correction)
668    } else {
669        // Very rough approximation for small df
670        Ok(z * (1.0 + (z * z + 1.0) / (4.0 * df)))
671    }
672}