use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive, NumAssign, One, ToPrimitive, Zero};
use scirs2_core::{simd_ops::SimdUnifiedOps, validation::*};
use std::marker::PhantomData;
/// Bayesian linear regression model with a conjugate Normal-Gamma prior and
/// a choice of inference back-ends (exact, variational, MCMC, EP).
#[derive(Debug, Clone)]
pub struct EnhancedBayesianRegression<F> {
    /// n x p design matrix X.
    pub design_matrix: Array2<F>,
    /// Length-n response vector y.
    pub response: Array1<F>,
    /// Prior over the coefficients and the noise precision.
    pub prior: BayesianRegressionPrior<F>,
    /// Inference algorithm selected at construction time.
    pub inference_method: InferenceMethod,
    /// Solver settings (iterations, tolerance, seed, parallelism flag).
    pub config: BayesianRegressionConfig,
    // NOTE(review): PhantomData appears redundant here since `F` already
    // occurs in the array fields — confirm before removing.
    _phantom: PhantomData<F>,
}
/// Conjugate prior for the linear model: beta ~ N(mean, precision^{-1}) and
/// a Gamma prior on the noise precision parameterized by shape and rate.
#[derive(Debug, Clone)]
pub struct BayesianRegressionPrior<F> {
    /// Prior mean mu_0 of the coefficients (length p).
    pub beta_mean: Array1<F>,
    /// Prior precision matrix Lambda_0 of the coefficients (p x p).
    pub beta_precision: Array2<F>,
    /// Shape a_0 of the Gamma prior on the noise precision.
    pub noiseshape: F,
    /// Rate b_0 of the Gamma prior on the noise precision.
    pub noise_rate: F,
}
/// Inference algorithm used by [`EnhancedBayesianRegression::fit`].
#[derive(Debug, Clone, PartialEq)]
pub enum InferenceMethod {
    /// Closed-form conjugate posterior.
    Exact,
    /// Mean-field variational Bayes (iterative).
    VariationalBayes,
    /// Gibbs sampling.
    MCMC,
    /// Expectation propagation (currently delegates to variational Bayes).
    ExpectationPropagation,
}
/// Solver configuration shared by the iterative inference methods.
#[derive(Debug, Clone)]
pub struct BayesianRegressionConfig {
    /// Iteration budget (VB iterations, or total MCMC draws).
    pub max_iter: usize,
    /// Convergence tolerance on successive ELBO values (VB).
    pub tolerance: f64,
    // Parallelism flag; not consulted by the code visible in this file.
    pub parallel: bool,
    /// RNG seed for MCMC; `None` seeds from the thread RNG.
    pub seed: Option<u64>,
}
impl Default for BayesianRegressionConfig {
    /// Defaults: 1000 iterations, 1e-6 tolerance, parallel enabled, unseeded.
    fn default() -> Self {
        Self {
            max_iter: 1000,
            tolerance: 1e-6,
            parallel: true,
            seed: None,
        }
    }
}
/// Posterior summary produced by any of the inference back-ends.
#[derive(Debug, Clone)]
pub struct BayesianRegressionResult<F> {
    /// Posterior mean of the coefficients (length p).
    pub beta_mean: Array1<F>,
    /// Posterior covariance of the coefficients (p x p).
    pub beta_covariance: Array2<F>,
    /// Posterior mean of the noise precision.
    pub noise_precision_mean: F,
    /// Posterior variance of the noise precision.
    pub noise_precision_var: F,
    /// Log marginal likelihood (or an ELBO/likelihood proxy, method-dependent).
    pub log_marginal_likelihood: F,
    /// In-sample predictive mean X * beta_mean (length n).
    pub predictive_mean: Array1<F>,
    /// Per-observation predictive variance (length n).
    pub predictive_var: Array1<F>,
    /// Iteration count and convergence status of the solver.
    pub convergence_info: ConvergenceInfo,
}
/// Convergence status of an iterative fit.
#[derive(Debug, Clone)]
pub struct ConvergenceInfo {
    /// Whether the stopping criterion was met within the iteration budget.
    pub converged: bool,
    /// Number of iterations actually performed.
    pub iterations: usize,
    /// Tolerance achieved; `f64::INFINITY` when not converged.
    pub final_tolerance: f64,
}
impl<F> EnhancedBayesianRegression<F>
where
F: Float
+ Zero
+ One
+ Copy
+ Send
+ Sync
+ SimdUnifiedOps
+ std::fmt::Display
+ 'static
+ std::iter::Sum
+ NumAssign
+ ScalarOperand
+ ToPrimitive
+ FromPrimitive,
{
/// Builds a model after validating finiteness and shape conformance of the
/// design matrix, response, and prior.
///
/// The configuration starts at [`BayesianRegressionConfig::default`]; use
/// [`Self::with_config`] to override it.
///
/// # Errors
/// Returns the finiteness-check error when any input contains NaN/inf, or
/// `StatsError::DimensionMismatch` when the response length, prior mean
/// length, or prior precision shape do not match the design matrix.
pub fn new(
    design_matrix: Array2<F>,
    response: Array1<F>,
    prior: BayesianRegressionPrior<F>,
    inference_method: InferenceMethod,
) -> StatsResult<Self> {
    checkarray_finite(&design_matrix, "design_matrix")?;
    checkarray_finite(&response, "response")?;
    checkarray_finite(&prior.beta_mean, "beta_mean")?;
    checkarray_finite(&prior.beta_precision, "beta_precision")?;
    let (n, p) = design_matrix.dim();
    if response.len() != n {
        // Fixed message typo: "design _matrix" -> "design matrix".
        return Err(StatsError::DimensionMismatch(format!(
            "Response length ({}) must match design matrix rows ({})",
            response.len(),
            n
        )));
    }
    if prior.beta_mean.len() != p {
        return Err(StatsError::DimensionMismatch(format!(
            "Prior mean length ({}) must match design matrix columns ({})",
            prior.beta_mean.len(),
            p
        )));
    }
    if prior.beta_precision.nrows() != p || prior.beta_precision.ncols() != p {
        return Err(StatsError::DimensionMismatch(format!(
            "Prior precision shape ({}, {}) must be ({}, {})",
            prior.beta_precision.nrows(),
            prior.beta_precision.ncols(),
            p,
            p
        )));
    }
    Ok(Self {
        design_matrix,
        response,
        prior,
        inference_method,
        config: BayesianRegressionConfig::default(),
        _phantom: PhantomData,
    })
}
pub fn with_config(mut self, config: BayesianRegressionConfig) -> Self {
self.config = config;
self
}
/// Runs inference with the method chosen at construction time and returns
/// the posterior summary.
pub fn fit(&self) -> StatsResult<BayesianRegressionResult<F>> {
    match self.inference_method {
        InferenceMethod::Exact => self.fit_exact(),
        InferenceMethod::VariationalBayes => self.fit_variational_bayes(),
        InferenceMethod::MCMC => self.fit_mcmc(),
        // EP currently delegates to the VB implementation (see below).
        InferenceMethod::ExpectationPropagation => self.fit_expectation_propagation(),
    }
}
/// Exact conjugate inference for the Normal-Gamma linear model.
///
/// Coefficient posterior: N(m_N, S_N) with S_N^{-1} = X'X + Lambda_0 and
/// m_N = S_N (X'y + Lambda_0 mu_0). Noise-precision posterior: Gamma with
/// shape a_0 + n/2 and rate b_0 + RSS/2, where RSS is evaluated at m_N.
///
/// Linear algebra is performed in f64 and converted back to `F` at the end.
///
/// # Errors
/// Fails when the posterior precision cannot be inverted or the marginal
/// likelihood determinants cannot be computed.
fn fit_exact(&self) -> StatsResult<BayesianRegressionResult<F>> {
    let x = &self.design_matrix;
    let y = &self.response;
    let n = x.nrows() as f64;
    let p = x.ncols();
    let xtx = x.t().dot(x);
    let xty = x.t().dot(y);
    let xtx_f64 = xtx.mapv(|v| v.to_f64().unwrap_or(0.0));
    let xty_f64 = xty.mapv(|v| v.to_f64().unwrap_or(0.0));
    let prior_precision_f64 = self
        .prior
        .beta_precision
        .mapv(|v| v.to_f64().unwrap_or(0.0));
    let prior_mean_f64 = self.prior.beta_mean.mapv(|v| v.to_f64().unwrap_or(0.0));
    let noiseshape_f64 = self.prior.noiseshape.to_f64().unwrap_or(1.0);
    let noise_rate_f64 = self.prior.noise_rate.to_f64().unwrap_or(1.0);
    // S_N^{-1} = X'X + Lambda_0
    let posterior_precision_f64 = xtx_f64.clone() + prior_precision_f64.clone();
    let posterior_covariance_f64 = scirs2_linalg::inv(&posterior_precision_f64.view(), None)
        .map_err(|e| {
            StatsError::ComputationError(format!("Failed to invert posterior precision: {}", e))
        })?;
    // m_N = S_N (X'y + Lambda_0 mu_0).
    // BUG FIX: this previously computed S_N (X'X (X'y) + Lambda_0 mu_0) —
    // the extra X'X factor is dimensionally valid but mathematically wrong.
    let posterior_mean_f64 = posterior_covariance_f64
        .dot(&(xty_f64.clone() + prior_precision_f64.dot(&prior_mean_f64)));
    let beta_mean: Array1<F> =
        posterior_mean_f64.mapv(|v| F::from(v).expect("Failed to convert to float"));
    // Residual sum of squares at the posterior mean.
    let residual = y - &x.dot(&beta_mean);
    let residual_sum_squares = residual.dot(&residual).to_f64().unwrap_or(0.0);
    let posterior_noiseshape = noiseshape_f64 + n / 2.0;
    let posterior_noise_rate = noise_rate_f64 + residual_sum_squares / 2.0;
    let beta_covariance =
        posterior_covariance_f64.mapv(|v| F::from(v).expect("Failed to convert to float"));
    // Mean and variance of the Gamma posterior: a/b and a/b^2.
    let noise_precision_mean = F::from(posterior_noiseshape / posterior_noise_rate)
        .expect("Failed to convert to float");
    let noise_precision_var =
        F::from(posterior_noiseshape / (posterior_noise_rate * posterior_noise_rate))
            .expect("Operation failed");
    let predictive_mean = x.dot(&beta_mean);
    let predictive_var_diag =
        self.compute_predictive_variance(x.view(), &beta_covariance, noise_precision_mean)?;
    let log_marginal_likelihood = self.compute_log_marginal_likelihood(
        &xtx_f64,
        &xty_f64,
        &prior_precision_f64,
        &prior_mean_f64,
        noiseshape_f64,
        noise_rate_f64,
        n,
        p,
    )?;
    Ok(BayesianRegressionResult {
        beta_mean,
        beta_covariance,
        noise_precision_mean,
        noise_precision_var,
        log_marginal_likelihood,
        predictive_mean,
        predictive_var: predictive_var_diag,
        // Exact inference is non-iterative, so it trivially "converges".
        convergence_info: ConvergenceInfo {
            converged: true,
            iterations: 1,
            final_tolerance: 0.0,
        },
    })
}
/// Mean-field variational inference: alternates closed-form updates of
/// q(beta) = N(m, S) and q(tau) = Gamma(a, b) until the ELBO proxy changes
/// by less than `config.tolerance` or `config.max_iter` is reached.
fn fit_variational_bayes(&self) -> StatsResult<BayesianRegressionResult<F>> {
    let x = &self.design_matrix;
    let y = &self.response;
    let (n, p) = x.dim();
    // Initialize the variational factors at the prior.
    let mut q_beta_mean = self.prior.beta_mean.clone();
    let mut q_beta_precision = self.prior.beta_precision.clone();
    let mut q_noiseshape = self.prior.noiseshape;
    let mut q_noise_rate = self.prior.noise_rate;
    let mut converged = false;
    let mut iterations = 0;
    let mut prev_elbo = F::neg_infinity();
    for iter in 0..self.config.max_iter {
        iterations = iter + 1;
        // NOTE(review): X'X and X'y are loop-invariant and could be hoisted
        // out of the iteration for a free speedup.
        let xtx = x.t().dot(x);
        let xty = x.t().dot(y);
        let expected_noise_precision = q_noiseshape / q_noise_rate;
        // q(beta) precision: Lambda_0 + E[tau] X'X
        q_beta_precision =
            self.prior.beta_precision.clone() + xtx.mapv(|v| v * expected_noise_precision);
        let q_beta_covariance = scirs2_linalg::inv(&q_beta_precision.view(), None)
            .map_err(|e| StatsError::ComputationError(format!("VB update failed: {}", e)))?;
        // q(beta) mean: S (Lambda_0 mu_0 + E[tau] X'y)
        q_beta_mean = q_beta_covariance.dot(
            &(self.prior.beta_precision.dot(&self.prior.beta_mean)
                + xty.mapv(|v| v * expected_noise_precision)),
        );
        // Gamma shape update a_0 + n/2 (constant across iterations).
        q_noiseshape = self.prior.noiseshape
            + F::from(n).expect("Failed to convert to float")
                / F::from(2.0).expect("Failed to convert constant to float");
        let _expected_beta_squared =
            q_beta_mean.dot(&q_beta_mean) + q_beta_covariance.diag().sum();
        // Expected residual term E[||y - X beta||^2]
        //   = ||y||^2 - 2 y'X m + ||X m||^2 + tr(X'X S).
        // NOTE(review): (A * B).diag().sum() is sum_i A_ii B_ii, which is NOT
        // tr(X'X S) in general; for symmetric S the trace would be
        // (A * B).sum(). Verify this term.
        let residual_term = y.dot(y)
            - F::from(2.0).expect("Failed to convert constant to float")
                * y.dot(&x.dot(&q_beta_mean))
            + x.dot(&q_beta_mean).dot(&x.dot(&q_beta_mean))
            + (x.t().dot(x) * q_beta_covariance).diag().sum();
        // Gamma rate update: b_0 + residual_term / 2.
        q_noise_rate = self.prior.noise_rate
            + residual_term / F::from(2.0).expect("Failed to convert constant to float");
        let elbo =
            self.compute_elbo(&q_beta_mean, &q_beta_precision, q_noiseshape, q_noise_rate)?;
        if (elbo - prev_elbo).abs()
            < F::from(self.config.tolerance).expect("Failed to convert to float")
        {
            converged = true;
            break;
        }
        prev_elbo = elbo;
    }
    // Recover the posterior covariance from the final precision.
    let beta_covariance = scirs2_linalg::inv(&q_beta_precision.view(), None).map_err(|e| {
        StatsError::ComputationError(format!("Final covariance computation failed: {}", e))
    })?;
    let noise_precision_mean = q_noiseshape / q_noise_rate;
    let noise_precision_var = q_noiseshape / (q_noise_rate * q_noise_rate);
    let predictive_mean = x.dot(&q_beta_mean);
    let predictive_var =
        self.compute_predictive_variance(x.view(), &beta_covariance, noise_precision_mean)?;
    // NOTE(review): on convergence via `break`, prev_elbo still holds the
    // previous iteration's value, so the reported ELBO lags by one update.
    let log_marginal_likelihood = prev_elbo;
    Ok(BayesianRegressionResult {
        beta_mean: q_beta_mean,
        beta_covariance,
        noise_precision_mean,
        noise_precision_var,
        log_marginal_likelihood,
        predictive_mean,
        predictive_var,
        convergence_info: ConvergenceInfo {
            converged,
            iterations,
            final_tolerance: if converged {
                self.config.tolerance
            } else {
                f64::INFINITY
            },
        },
    })
}
/// Gibbs sampler for the conjugate linear model: alternately samples
/// beta | tau from its multivariate normal full conditional and
/// tau | beta from its Gamma full conditional, keeping draws after a
/// burn-in of one quarter of the run.
fn fit_mcmc(&self) -> StatsResult<BayesianRegressionResult<F>> {
    use scirs2_core::random::rngs::StdRng;
    use scirs2_core::random::SeedableRng;
    use scirs2_core::random::{Distribution, Gamma};
    let x = &self.design_matrix;
    let y = &self.response;
    let (n, p) = x.dim();
    // max_iter doubles as the total number of Gibbs iterations.
    let n_samples_ = self.config.max_iter;
    let n_burnin = n_samples_ / 4; let n_thin = 1;
    // Seeded RNG for reproducibility; otherwise seed from the thread RNG.
    let mut rng = match self.config.seed {
        Some(seed) => StdRng::seed_from_u64(seed),
        None => {
            let mut rng = scirs2_core::random::thread_rng();
            StdRng::from_rng(&mut rng)
        }
    };
    #[allow(unused_assignments)]
    let mut beta = self.prior.beta_mean.clone();
    // Initialize tau at its prior mean (shape / rate).
    let mut noise_precision = self.prior.noiseshape / self.prior.noise_rate;
    let mut beta_samples = Vec::with_capacity(n_samples_ - n_burnin);
    let mut noise_precision_samples_ = Vec::with_capacity(n_samples_ - n_burnin);
    let mut log_likelihood_history = Vec::new();
    // X'X and X'y are fixed across iterations; compute once.
    let xtx = x.t().dot(x);
    let xty = x.t().dot(y);
    for iter in 0..n_samples_ {
        // Full conditional of beta: N(S (Lambda_0 mu_0 + tau X'y), S)
        // with S^{-1} = Lambda_0 + tau X'X.
        let precision_matrix =
            self.prior.beta_precision.clone() + xtx.mapv(|v| v * noise_precision);
        let precision_f64 = precision_matrix.mapv(|v| v.to_f64().unwrap_or(0.0));
        let posterior_cov_f64 =
            scirs2_linalg::inv(&precision_f64.view(), None).map_err(|e| {
                StatsError::ComputationError(format!("MCMC covariance inversion failed: {}", e))
            })?;
        let mean_term = self.prior.beta_precision.dot(&self.prior.beta_mean)
            + xty.mapv(|v| v * noise_precision);
        let posterior_mean_f64 =
            posterior_cov_f64.dot(&mean_term.mapv(|v| v.to_f64().unwrap_or(0.0)));
        beta =
            self.sample_multivariate_normal(&posterior_mean_f64, &posterior_cov_f64, &mut rng)?;
        // Full conditional of tau: Gamma(a_0 + n/2, rate = b_0 + RSS/2).
        let residual = y - &x.dot(&beta);
        let sum_squared_residuals = residual.dot(&residual).to_f64().unwrap_or(0.0);
        let posteriorshape = self.prior.noiseshape.to_f64().unwrap_or(1.0) + (n as f64) / 2.0;
        let posterior_rate =
            self.prior.noise_rate.to_f64().unwrap_or(1.0) + sum_squared_residuals / 2.0;
        // NOTE(review): passing 1/rate assumes Gamma::new takes (shape, scale)
        // — confirm against the scirs2_core distribution API.
        let gamma_dist = Gamma::new(posteriorshape, 1.0 / posterior_rate).map_err(|e| {
            StatsError::ComputationError(format!("Failed to create gamma distribution: {}", e))
        })?;
        noise_precision = F::from(gamma_dist.sample(&mut rng)).expect("Operation failed");
        // Keep post-burn-in draws (thinning interval is 1, i.e. keep all).
        if iter >= n_burnin && (iter - n_burnin).is_multiple_of(n_thin) {
            beta_samples.push(beta.clone());
            noise_precision_samples_.push(noise_precision);
        }
        // Monitor the data log-likelihood every 100 iterations.
        if iter % 100 == 0 {
            let ll = self.compute_mcmc_log_likelihood(&beta, noise_precision)?;
            log_likelihood_history.push(ll);
        }
    }
    let n_kept_samples = beta_samples.len();
    if n_kept_samples == 0 {
        return Err(StatsError::ComputationError(
            "No MCMC samples collected".to_string(),
        ));
    }
    // Posterior mean of beta: average of the kept draws.
    let mut posterior_beta_mean = Array1::zeros(p);
    for sample in &beta_samples {
        posterior_beta_mean += sample;
    }
    posterior_beta_mean /= F::from(n_kept_samples).expect("Failed to convert to float");
    // Sample covariance of the kept beta draws (denominator n-1).
    let mut posterior_beta_cov = Array2::zeros((p, p));
    for sample in &beta_samples {
        let centered = sample - &posterior_beta_mean;
        for i in 0..p {
            for j in 0..p {
                posterior_beta_cov[[i, j]] += centered[i] * centered[j];
            }
        }
    }
    posterior_beta_cov /=
        F::from(n_kept_samples.saturating_sub(1).max(1)).expect("Operation failed");
    // Mean and sample variance of the noise-precision chain.
    let noise_precision_mean = noise_precision_samples_
        .iter()
        .fold(F::zero(), |acc, &x| acc + x)
        / F::from(n_kept_samples).expect("Failed to convert to float");
    let noise_precision_var = {
        let mean_sq = noise_precision_samples_
            .iter()
            .map(|&x| (x - noise_precision_mean) * (x - noise_precision_mean))
            .fold(F::zero(), |acc, x| acc + x)
            / F::from(n_kept_samples.saturating_sub(1).max(1)).expect("Operation failed");
        mean_sq
    };
    let predictive_mean = x.dot(&posterior_beta_mean);
    let predictive_var =
        self.compute_predictive_variance(x.view(), &posterior_beta_cov, noise_precision_mean)?;
    // Report the most recent monitored likelihood, or recompute at the
    // posterior mean if the monitor never fired.
    let final_log_likelihood = if log_likelihood_history.is_empty() {
        self.compute_mcmc_log_likelihood(&posterior_beta_mean, noise_precision_mean)?
    } else {
        *log_likelihood_history.last().expect("Operation failed")
    };
    let converged = self.check_mcmc_convergence(&beta_samples, &noise_precision_samples_)?;
    Ok(BayesianRegressionResult {
        beta_mean: posterior_beta_mean,
        beta_covariance: posterior_beta_cov,
        noise_precision_mean,
        noise_precision_var,
        // For MCMC this field actually carries a data log-likelihood,
        // not a marginal likelihood.
        log_marginal_likelihood: final_log_likelihood,
        predictive_mean,
        predictive_var,
        convergence_info: ConvergenceInfo {
            converged,
            iterations: n_samples_,
            final_tolerance: if converged {
                self.config.tolerance
            } else {
                f64::INFINITY
            },
        },
    })
}
/// Expectation propagation is not yet implemented; this delegates to the
/// variational Bayes solver, which targets the same posterior family.
fn fit_expectation_propagation(&self) -> StatsResult<BayesianRegressionResult<F>> {
    self.fit_variational_bayes()
}
/// Diagonal of the posterior predictive covariance: for each row x_i of `x`,
/// x_i' S x_i (coefficient uncertainty) plus 1/E[tau] (mean noise variance).
fn compute_predictive_variance(
    &self,
    x: ArrayView2<F>,
    beta_covariance: &Array2<F>,
    noise_precision_mean: F,
) -> StatsResult<Array1<F>> {
    let noise_var = F::one() / noise_precision_mean;
    let variances: Vec<F> = x
        .outer_iter()
        .map(|row| row.dot(&beta_covariance.dot(&row)) + noise_var)
        .collect();
    Ok(Array1::from_vec(variances))
}
/// Approximate log marginal likelihood from the precision determinants.
///
/// NOTE(review): this keeps only the 0.5*ln(|Lambda_0|/|S_N^{-1}|) Occam
/// term, a shape*ln(rate) term, and the Gaussian normalization; the
/// data-dependent quadratic term and the Gamma-function terms of the full
/// evidence are omitted (the `_xty`/`_prior_mean` parameters are unused).
/// Treat the value as a relative score, not an exact evidence.
fn compute_log_marginal_likelihood(
    &self,
    xtx: &Array2<f64>,
    _xty: &Array1<f64>,
    prior_precision: &Array2<f64>,
    _prior_mean: &Array1<f64>,
    noiseshape: f64,
    noise_rate: f64,
    n: f64,
    p: usize,
) -> StatsResult<F> {
    let posterior_precision = xtx + prior_precision;
    let det_prior = scirs2_linalg::det(&prior_precision.view(), None).map_err(|e| {
        StatsError::ComputationError(format!("Determinant computation failed: {}", e))
    })?;
    let det_posterior = scirs2_linalg::det(&posterior_precision.view(), None).map_err(|e| {
        StatsError::ComputationError(format!("Determinant computation failed: {}", e))
    })?;
    // p is accepted for interface symmetry but not used in this formula.
    let log_ml = 0.5 * (det_prior / det_posterior).ln() + noiseshape * noise_rate.ln()
        - (n / 2.0) * (2.0 * std::f64::consts::PI).ln();
    Ok(F::from(log_ml).expect("Failed to convert to float"))
}
/// ELBO proxy used as the VB convergence monitor.
///
/// NOTE(review): only the expected data-fit term
/// -0.5 * E[tau] * ||y - X m||^2 is computed; the KL terms and entropies of
/// the full evidence lower bound are omitted (the precision parameter is
/// unused). Adequate as a stopping heuristic, not as an evidence estimate.
fn compute_elbo(
    &self,
    q_beta_mean: &Array1<F>,
    _q_beta_precision: &Array2<F>,
    q_noiseshape: F,
    q_noise_rate: F,
) -> StatsResult<F> {
    // E[tau] under q(tau) = Gamma(shape, rate) is shape / rate.
    let expected_noise_precision = q_noiseshape / q_noise_rate;
    let residual = &self.response - &self.design_matrix.dot(q_beta_mean);
    let data_term = -F::from(0.5).expect("Failed to convert constant to float")
        * expected_noise_precision
        * residual.dot(&residual);
    Ok(data_term)
}
/// Draws one sample from N(mean, covariance) via the Cholesky factor:
/// sample = mean + L z with z ~ N(0, I) and covariance = L L'.
///
/// # Errors
/// Fails when `covariance` is not positive definite (Cholesky failure).
fn sample_multivariate_normal<R: scirs2_core::random::Rng>(
    &self,
    mean: &Array1<f64>,
    covariance: &Array2<f64>,
    rng: &mut R,
) -> StatsResult<Array1<F>> {
    use scirs2_core::random::{Distribution, StandardNormal};
    let d = mean.len();
    let chol = scirs2_linalg::cholesky(&covariance.view(), None).map_err(|e| {
        StatsError::ComputationError(format!("Cholesky decomposition failed: {}", e))
    })?;
    // d independent standard normal draws.
    let z: Vec<f64> = (0..d).map(|_| StandardNormal.sample(rng)).collect();
    let z_array = Array1::from_vec(z);
    let sample_f64 = mean + &chol.dot(&z_array);
    let sample = sample_f64.mapv(|x| F::from(x).expect("Failed to convert to float"));
    Ok(sample)
}
/// Gaussian log-likelihood of the data at the given coefficients and noise
/// precision tau: (n/2) ln tau - (n/2) ln(2 pi) - (tau/2) ||y - X beta||^2,
/// evaluated in f64.
fn compute_mcmc_log_likelihood(&self, beta: &Array1<F>, noise_precision: F) -> StatsResult<F> {
    let design = &self.design_matrix;
    let resid = &self.response - &design.dot(beta);
    let rss = resid.dot(&resid).to_f64().unwrap_or(0.0);
    let tau = noise_precision.to_f64().unwrap_or(1.0);
    let half_n = design.nrows() as f64 / 2.0;
    let ll = half_n * tau.ln() - half_n * (2.0 * std::f64::consts::PI).ln() - 0.5 * tau * rss;
    Ok(F::from(ll).expect("Failed to convert to float"))
}
/// Heuristic convergence diagnostics for the Gibbs chains: requires at
/// least 100 kept draws, roughly equal variance between the two halves of
/// the first coefficient's chain (ratio <= 2), and an effective sample size
/// of at least 100 for the noise-precision chain.
fn check_mcmc_convergence(
    &self,
    beta_samples: &[Array1<F>],
    noise_precision_samples_: &[F],
) -> StatsResult<bool> {
    if beta_samples.len() < 100 {
        return Ok(false); }
    let n = beta_samples.len();
    let mid = n / 2;
    let first_half = &beta_samples[..mid];
    let second_half = &beta_samples[mid..];
    if !beta_samples.is_empty() && !beta_samples[0].is_empty() {
        // Crude stationarity check on the first coefficient only: the two
        // half-chain variances should agree within a factor of 2.
        let first_half_var = self
            .compute_sample_variance_1d(&first_half.iter().map(|x| x[0]).collect::<Vec<_>>());
        let second_half_var = self
            .compute_sample_variance_1d(&second_half.iter().map(|x| x[0]).collect::<Vec<_>>());
        let var_ratio =
            first_half_var.max(second_half_var) / first_half_var.min(second_half_var);
        if var_ratio > F::from(2.0).expect("Failed to convert constant to float") {
            return Ok(false); }
    }
    let eff_samplesize = self.compute_effective_samplesize(noise_precision_samples_)?;
    if eff_samplesize < 100.0 {
        return Ok(false); }
    Ok(true)
}
/// Unbiased (n-1 denominator) sample variance of a scalar chain, floored at
/// 1e-10 so callers can divide by it; returns 1 for an empty slice.
fn compute_sample_variance_1d(&self, samples: &[F]) -> F {
    if samples.is_empty() {
        return F::one();
    }
    let count = samples.len();
    let total = samples.iter().fold(F::zero(), |acc, &v| acc + v);
    let mean = total / F::from(count).expect("Failed to convert to float");
    let sum_sq = samples
        .iter()
        .fold(F::zero(), |acc, &v| acc + (v - mean) * (v - mean));
    let variance =
        sum_sq / F::from(count.saturating_sub(1).max(1)).expect("Operation failed");
    variance.max(F::from(1e-10).expect("Failed to convert constant to float"))
}
/// Effective sample size from the lag-1 autocorrelation rho:
/// ESS = n (1 - rho) / (1 + rho), applied only when rho > 0.1; otherwise the
/// raw chain length is returned. Chains shorter than 10 are returned as-is.
fn compute_effective_samplesize(&self, samples: &[F]) -> StatsResult<f64> {
    if samples.len() < 10 {
        // Too few draws for a meaningful autocorrelation estimate.
        return Ok(samples.len() as f64);
    }
    let n = samples.len();
    let mean = samples.iter().fold(F::zero(), |acc, &x| acc + x)
        / F::from(n).expect("Failed to convert to float");
    // Lag-1 autocovariance over centered draws (numerator) and the
    // zero-lag sum of squares (denominator).
    let mut numerator = F::zero();
    let mut denominator = F::zero();
    for i in 0..n - 1 {
        let x_i = samples[i] - mean;
        let x_i1 = samples[i + 1] - mean;
        numerator += x_i * x_i1;
        denominator += x_i * x_i;
    }
    // Guard against a (near-)constant chain.
    let autocorr = if denominator > F::from(1e-10).expect("Failed to convert constant to float")
    {
        (numerator / denominator).to_f64().unwrap_or(0.0)
    } else {
        0.0
    };
    let eff_n = if autocorr > 0.1 {
        n as f64 * (1.0 - autocorr) / (1.0 + autocorr)
    } else {
        n as f64
    };
    Ok(eff_n.max(1.0))
}
/// Posterior predictive mean and per-row variance for new inputs `x_new`,
/// using the posterior summary from a previous `fit`.
///
/// # Errors
/// Fails the finiteness check on `x_new`, or returns
/// `StatsError::DimensionMismatch` when `x_new` has a different number of
/// columns than the training design matrix.
pub fn predict(
    &self,
    x_new: &Array2<F>,
    result: &BayesianRegressionResult<F>,
) -> StatsResult<(Array1<F>, Array1<F>)> {
    checkarray_finite(x_new, "x_new")?;
    if x_new.ncols() != self.design_matrix.ncols() {
        return Err(StatsError::DimensionMismatch(format!(
            "New data columns ({}) must match training data columns ({})",
            x_new.ncols(),
            self.design_matrix.ncols()
        )));
    }
    let pred_mean = x_new.dot(&result.beta_mean);
    let pred_var = self.compute_predictive_variance(
        x_new.view(),
        &result.beta_covariance,
        result.noise_precision_mean,
    )?;
    Ok((pred_mean, pred_var))
}
}
impl<F> BayesianRegressionPrior<F>
where
F: Float + Zero + One + Copy + ScalarOperand + std::fmt::Display + FromPrimitive,
{
/// Vague prior for `p` coefficients: zero mean with near-zero isotropic
/// precision (1e-6 I) and a diffuse Gamma(1e-3, 1e-3) prior on the noise
/// precision.
pub fn uninformative(p: usize) -> Self {
    let tiny = F::from(1e-6).expect("Failed to convert constant to float");
    let diffuse = F::from(1e-3).expect("Failed to convert constant to float");
    Self {
        beta_mean: Array1::zeros(p),
        beta_precision: Array2::eye(p) * tiny,
        noiseshape: diffuse,
        noise_rate: diffuse,
    }
}
/// Ridge-style prior for `p` coefficients: zero mean with isotropic
/// precision `alpha` I, and a Gamma(1, 1) prior on the noise precision.
pub fn ridge(p: usize, alpha: F) -> Self {
    Self {
        beta_mean: Array1::zeros(p),
        beta_precision: Array2::eye(p) * alpha,
        noiseshape: F::one(),
        noise_rate: F::one(),
    }
}
}
/// Convenience wrapper: exact conjugate Bayesian linear regression of `y`
/// on `x`, using an uninformative prior when `prior` is `None`.
///
/// # Errors
/// Propagates validation errors from model construction and numerical
/// errors from the exact fit.
#[allow(dead_code)]
pub fn bayesian_linear_regression_exact<F>(
    x: Array2<F>,
    y: Array1<F>,
    prior: Option<BayesianRegressionPrior<F>>,
) -> StatsResult<BayesianRegressionResult<F>>
where
    F: Float
        + Zero
        + One
        + Copy
        + Send
        + Sync
        + SimdUnifiedOps
        + 'static
        + std::iter::Sum
        + NumAssign
        + ScalarOperand
        + std::fmt::Display
        + ToPrimitive
        + FromPrimitive,
{
    let p = x.ncols();
    let prior = prior.unwrap_or_else(|| BayesianRegressionPrior::uninformative(p));
    let model = EnhancedBayesianRegression::new(x, y, prior, InferenceMethod::Exact)?;
    model.fit()
}
/// Convenience wrapper: variational-Bayes linear regression of `y` on `x`,
/// using an uninformative prior and the default solver configuration when
/// `prior`/`config` are `None`.
///
/// # Errors
/// Propagates validation errors from model construction and numerical
/// errors from the VB fit.
#[allow(dead_code)]
pub fn bayesian_linear_regression_vb<F>(
    x: Array2<F>,
    y: Array1<F>,
    prior: Option<BayesianRegressionPrior<F>>,
    config: Option<BayesianRegressionConfig>,
) -> StatsResult<BayesianRegressionResult<F>>
where
    F: Float
        + Zero
        + One
        + Copy
        + Send
        + Sync
        + SimdUnifiedOps
        + 'static
        + std::iter::Sum
        + NumAssign
        + ScalarOperand
        + std::fmt::Display
        + ToPrimitive
        + FromPrimitive,
{
    let p = x.ncols();
    let prior = prior.unwrap_or_else(|| BayesianRegressionPrior::uninformative(p));
    let config = config.unwrap_or_default();
    let model = EnhancedBayesianRegression::new(x, y, prior, InferenceMethod::VariationalBayes)?
        .with_config(config);
    model.fit()
}