use crate::error::{AprenderError, Result};
use crate::primitives::{Matrix, Vector};
/// Bayesian linear regression with an isotropic Gaussian prior on the
/// coefficients.
///
/// `fit` estimates the noise variance from OLS residuals and computes the
/// Gaussian posterior (mean and precision matrix) over the coefficients.
#[derive(Debug, Clone)]
pub struct BayesianLinearRegression {
// Number of predictor columns expected in X.
n_features: usize,
// Prior mean of the coefficient vector (length == n_features).
beta_prior_mean: Vec<f32>,
// Scalar precision of the isotropic Gaussian coefficient prior.
beta_prior_precision: f32,
// Inverse-Gamma shape hyperparameter for the noise prior; stored but not
// used by the current residual-based noise estimate in fit().
#[allow(dead_code)]
noise_alpha: f32,
// Inverse-Gamma rate hyperparameter for the noise prior (currently unused).
#[allow(dead_code)]
noise_beta: f32,
// Posterior mean of the coefficients; None until fit() succeeds.
posterior_mean: Option<Vec<f32>>,
// Posterior precision matrix stored as rows; None until fit() succeeds.
posterior_precision: Option<Vec<Vec<f32>>>,
// Estimated noise variance sigma^2; None until fit() succeeds.
noise_variance: Option<f32>,
}
impl BayesianLinearRegression {
/// Creates a model with weakly-informative defaults: a zero-mean coefficient
/// prior with near-zero precision (1e-4) and vague Inverse-Gamma noise
/// hyperparameters (alpha = beta = 1e-3). No posterior exists until `fit`.
#[must_use]
pub fn new(n_features: usize) -> Self {
    let zero_mean = vec![0.0; n_features];
    Self {
        n_features,
        beta_prior_mean: zero_mean,
        beta_prior_precision: 1e-4,
        noise_alpha: 1e-3,
        noise_beta: 1e-3,
        posterior_mean: None,
        posterior_precision: None,
        noise_variance: None,
    }
}
/// Creates a model with an explicit Gaussian coefficient prior and
/// Inverse-Gamma noise hyperparameters.
///
/// # Errors
///
/// Returns `DimensionMismatch` if `beta_prior_mean.len() != n_features`, and
/// `InvalidHyperparameter` if any scalar hyperparameter is not a finite value
/// strictly greater than zero (NaN and infinity are rejected).
pub fn with_prior(
    n_features: usize,
    beta_prior_mean: Vec<f32>,
    beta_prior_precision: f32,
    noise_alpha: f32,
    noise_beta: f32,
) -> Result<Self> {
    if beta_prior_mean.len() != n_features {
        return Err(AprenderError::DimensionMismatch {
            expected: format!("{n_features} features"),
            actual: format!("{} elements in beta_prior_mean", beta_prior_mean.len()),
        });
    }
    // NOTE: `x <= 0.0` is false for NaN, so the previous check silently
    // accepted NaN hyperparameters; test finiteness and positivity explicitly.
    if !(beta_prior_precision.is_finite() && beta_prior_precision > 0.0) {
        return Err(AprenderError::InvalidHyperparameter {
            param: "beta_prior_precision".to_string(),
            value: beta_prior_precision.to_string(),
            constraint: "must be finite and > 0".to_string(),
        });
    }
    if !(noise_alpha.is_finite() && noise_alpha > 0.0)
        || !(noise_beta.is_finite() && noise_beta > 0.0)
    {
        return Err(AprenderError::InvalidHyperparameter {
            param: "noise_alpha or noise_beta".to_string(),
            value: format!("α={noise_alpha}, β={noise_beta}"),
            constraint: "both must be finite and > 0".to_string(),
        });
    }
    Ok(Self {
        n_features,
        beta_prior_mean,
        beta_prior_precision,
        noise_alpha,
        noise_beta,
        posterior_mean: None,
        posterior_precision: None,
        noise_variance: None,
    })
}
/// Returns the number of predictor columns this model expects in X.
#[must_use]
pub fn n_features(&self) -> usize {
self.n_features
}
/// Returns the posterior mean of the coefficients, or `None` if the model
/// has not been fitted yet.
#[must_use]
pub fn posterior_mean(&self) -> Option<&[f32]> {
self.posterior_mean.as_deref()
}
/// Returns the estimated noise variance sigma^2, or `None` if the model
/// has not been fitted yet.
#[must_use]
pub fn noise_variance(&self) -> Option<f32> {
self.noise_variance
}
/// Fits the Gaussian posterior over the coefficients given design matrix `x`
/// and targets `y`.
///
/// Solves OLS once to estimate the noise variance from the residuals, then
/// combines the prior precision with the data precision `X^T X / sigma^2` to
/// obtain the posterior precision, and solves for the posterior mean.
///
/// # Errors
///
/// Returns an error if dimensions are inconsistent, if `n <= p` (the
/// residual-based noise estimate needs strictly more samples than features),
/// or if a linear solve fails.
pub fn fit(&mut self, x: &Matrix<f32>, y: &Vector<f32>) -> Result<()> {
    let n = x.n_rows();
    let p = x.n_cols();
    if p != self.n_features {
        return Err(AprenderError::DimensionMismatch {
            expected: format!("{} features in X", self.n_features),
            actual: format!("{p} columns in X"),
        });
    }
    if n != y.len() {
        return Err(AprenderError::DimensionMismatch {
            expected: format!("{n} samples in X"),
            actual: format!("{} samples in y", y.len()),
        });
    }
    // n == p would leave zero residual degrees of freedom and divide by zero
    // when computing sigma2 below, so require strictly more samples.
    if n <= p {
        return Err(AprenderError::Other(format!(
            "Need more than {p} samples for {p} features, got {n}"
        )));
    }
    let xt = x.transpose();
    let xtx = xt.matmul(x).map_err(|e| AprenderError::Other(e.into()))?;
    let xty = xt.matvec(y).map_err(|e| AprenderError::Other(e.into()))?;
    // OLS point estimate, used only to derive the residual noise variance.
    let beta_ols = xtx
        .cholesky_solve(&xty)
        .map_err(|e| AprenderError::Other(format!("Cholesky decomposition failed: {e}")))?;
    let y_pred = x.matvec(&beta_ols).map_err(|e| {
        AprenderError::Other(format!("Matrix-vector multiplication failed: {e}"))
    })?;
    let mut rss = 0.0_f32;
    for i in 0..n {
        let residual = y[i] - y_pred[i];
        rss += residual * residual;
    }
    // Unbiased residual variance; clamp away from zero so a perfect fit
    // (rss == 0) cannot produce an infinite data precision below.
    let sigma2 = (rss / ((n - p) as f32)).max(f32::EPSILON);
    // Posterior precision = prior precision + X^T X / sigma^2. (The previous
    // local name `posterior_precision_inv` was misleading: this sum IS the
    // precision matrix, not its inverse.)
    let prior_precision_matrix = Matrix::eye(p).mul_scalar(self.beta_prior_precision);
    let data_precision = xtx.mul_scalar(1.0 / sigma2);
    let posterior_precision = prior_precision_matrix
        .add(&data_precision)
        .map_err(|e| AprenderError::Other(format!("Matrix addition failed: {e}")))?;
    // Right-hand side: prior pull toward the prior mean plus data evidence.
    let mut rhs = Vec::with_capacity(p);
    for i in 0..p {
        let prior_term = self.beta_prior_mean[i] * self.beta_prior_precision;
        let data_term = xty[i] / sigma2;
        rhs.push(prior_term + data_term);
    }
    let rhs_vec = Vector::from_vec(rhs);
    // Solve (posterior precision) * mean = rhs for the posterior mean.
    let posterior_mean = posterior_precision
        .cholesky_solve(&rhs_vec)
        .map_err(|e| AprenderError::Other(format!("Posterior mean computation failed: {e}")))?;
    self.posterior_mean = Some(posterior_mean.as_slice().to_vec());
    self.noise_variance = Some(sigma2);
    // Store the posterior precision matrix as p rows of length p.
    let precision_data: Vec<Vec<f32>> = posterior_precision
        .as_slice()
        .chunks(p)
        .map(|row: &[f32]| row.to_vec())
        .collect();
    self.posterior_precision = Some(precision_data);
    Ok(())
}
/// Predicts mean responses for `x_test` using the posterior mean coefficients.
///
/// # Errors
///
/// Returns an error if the model has not been fitted yet, or if `x_test` does
/// not have exactly `n_features` columns.
pub fn predict(&self, x_test: &Matrix<f32>) -> Result<Vector<f32>> {
    let posterior_mean = match &self.posterior_mean {
        Some(mean) => mean,
        None => {
            return Err(AprenderError::Other(
                "Model not fitted yet. Call fit() first.".into(),
            ))
        }
    };
    if x_test.n_cols() != self.n_features {
        return Err(AprenderError::DimensionMismatch {
            expected: format!("{} features", self.n_features),
            actual: format!("{} columns in x_test", x_test.n_cols()),
        });
    }
    let beta = Vector::from_slice(posterior_mean);
    x_test
        .matvec(&beta)
        .map_err(|e| AprenderError::Other(format!("Prediction failed: {e}")))
}
/// Gaussian log-likelihood of `(x, y)` under the posterior mean coefficients
/// and the fitted noise variance.
///
/// # Errors
///
/// Returns an error if the model is unfitted or dimensions are inconsistent.
pub fn log_likelihood(&self, x: &Matrix<f32>, y: &Vector<f32>) -> Result<f32> {
    use std::f32::consts::PI;
    let posterior_mean = self.posterior_mean.as_ref().ok_or_else(|| {
        AprenderError::Other("Model not fitted yet. Call fit() first.".into())
    })?;
    let sigma2 = self.noise_variance.ok_or_else(|| {
        AprenderError::Other("Noise variance not available. Call fit() first.".into())
    })?;
    let n = x.n_rows() as f32;
    if x.n_cols() != self.n_features {
        return Err(AprenderError::DimensionMismatch {
            expected: format!("{} features", self.n_features),
            actual: format!("{} columns in x", x.n_cols()),
        });
    }
    if x.n_rows() != y.len() {
        return Err(AprenderError::DimensionMismatch {
            expected: format!("{} samples in x", x.n_rows()),
            actual: format!("{} samples in y", y.len()),
        });
    }
    let beta = Vector::from_slice(posterior_mean);
    let y_pred = x
        .matvec(&beta)
        .map_err(|e| AprenderError::Other(format!("Prediction failed: {e}")))?;
    // Residual sum of squares against the posterior-mean prediction.
    let rss: f32 = (0..y.len())
        .map(|i| {
            let r = y[i] - y_pred[i];
            r * r
        })
        .sum();
    // ln N(y | X beta, sigma^2 I) = -n/2 ln(2 pi) - n/2 ln sigma^2 - RSS/(2 sigma^2)
    Ok(-0.5 * n * (2.0 * PI).ln() - 0.5 * n * sigma2.ln() - rss / (2.0 * sigma2))
}
/// Bayesian Information Criterion: `-2 logL + k ln(n)`, where `k` counts the
/// coefficients plus one for the noise variance. Lower is better.
///
/// # Errors
///
/// Propagates any error from [`Self::log_likelihood`].
pub fn bic(&self, x: &Matrix<f32>, y: &Vector<f32>) -> Result<f32> {
    let log_lik = self.log_likelihood(x, y)?;
    let k = (self.n_features + 1) as f32;
    let log_n = (x.n_rows() as f32).ln();
    Ok(k * log_n - 2.0 * log_lik)
}
/// Akaike Information Criterion: `-2 logL + 2k`, where `k` counts the
/// coefficients plus one for the noise variance. Lower is better.
///
/// # Errors
///
/// Propagates any error from [`Self::log_likelihood`].
pub fn aic(&self, x: &Matrix<f32>, y: &Vector<f32>) -> Result<f32> {
    let k = (self.n_features + 1) as f32;
    self.log_likelihood(x, y).map(|ll| 2.0 * k - 2.0 * ll)
}
}
#[cfg(test)]
#[path = "regression_tests.rs"]
mod tests;
#[cfg(test)]
#[path = "tests_blr_contract.rs"]
mod tests_blr_contract;