use std::f64::consts::PI;
use faer::linalg::{matmul, solvers::Solve};
use faer::{Accum, Mat, Par, Side};
use crate::gaussian::cholesky_logdet;
use crate::{BLRError, Gaussian};
/// Optional prior / warm start accepted by [`fit_with_prior`].
///
/// All three vectors describe the same `d`-dimensional weight space: `mean` has length `d`,
/// `cov` is a `d x d` matrix stored row-major (length `d * d`), and `alphas` holds one
/// precision per feature (length `d`). Use [`BLRPrior::validate`] to check these invariants.
///
/// In `fit_with_prior`, only `alphas` influences the iteration (as the initial per-feature
/// precisions); `mean` and `cov` are checked but do not enter the update equations.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct BLRPrior {
/// Mean vector of length `d`; `d` is defined as `mean.len()` by `validate`.
pub mean: Vec<f64>,
/// Row-major `d x d` matrix; `validate` requires it to admit a Cholesky factorisation.
pub cov: Vec<f64>,
/// Per-feature ARD precisions (length `d`), used as the starting `alpha` of the fit.
pub alphas: Vec<f64>,
}
impl BLRPrior {
    /// Checks that the prior is internally consistent and usable as a warm start.
    ///
    /// Let `d = self.mean.len()`. The prior is valid when
    ///
    /// * `alphas.len() == d` and `cov.len() == d * d`,
    /// * `d > 0`,
    /// * every `alphas[j]` is finite and strictly positive: they are precisions, so
    ///   `diag(alphas)` must be positive definite, and the log evidence takes `ln(alpha)`,
    /// * `cov`, read row-major, admits a Cholesky factorisation. Only the lower triangle is
    ///   read by the factorisation, so symmetry of `cov` is not checked.
    ///
    /// # Errors
    ///
    /// * [`BLRError::DimMismatch`] for a length mismatch (`expected` is the required length),
    ///   or for `d == 0` (reported as `expected: 1, got: 0`).
    /// * [`BLRError::SingularMatrix`] if any alpha is not a positive finite number, or if
    ///   `cov` is not positive definite.
    pub fn validate(&self) -> Result<(), BLRError> {
        let d = self.mean.len();
        if self.alphas.len() != d {
            return Err(BLRError::DimMismatch {
                expected: d,
                got: self.alphas.len(),
            });
        }
        if self.cov.len() != d * d {
            return Err(BLRError::DimMismatch {
                expected: d * d,
                got: self.cov.len(),
            });
        }
        if d == 0 {
            return Err(BLRError::DimMismatch {
                expected: 1,
                got: 0,
            });
        }
        // diag(alphas) is a precision matrix: it is positive definite only if every entry is
        // a positive finite number. NaN fails `is_finite`, so it is rejected here too.
        if self
            .alphas
            .iter()
            .any(|&a| !a.is_finite() || a <= 0.0)
        {
            return Err(BLRError::SingularMatrix);
        }
        let cov_mat = Mat::<f64>::from_fn(d, d, |i, j| self.cov[i * d + j]);
        cov_mat
            .llt(Side::Lower)
            .map_err(|_| BLRError::SingularMatrix)?;
        Ok(())
    }
}
/// Settings for the ARD evidence iteration run by [`fit`] and [`fit_with_prior`].
#[derive(Debug, Clone)]
pub struct ArdConfig {
/// Starting precision given to every feature weight when no prior supplies `alphas`.
pub alpha_init: f64,
/// Starting noise precision `beta`.
pub beta_init: f64,
/// Iteration budget for the evidence loop.
pub max_iter: usize,
/// The loop stops once the change in log evidence between iterations drops below this value.
pub tol: f64,
/// Whether `beta` is re-estimated each iteration; when `false` it stays at `beta_init`.
pub update_beta: bool,
}
impl Default for ArdConfig {
fn default() -> Self {
Self {
alpha_init: 1.0,
beta_init: 1.0,
max_iter: 100,
tol: 1e-5,
update_beta: true,
}
}
}
/// Per-row predictive marginals returned by `FittedArd::predict`.
///
/// All vectors have one entry per test row. Variances are combined in quadrature:
/// `total_std[i]² = aleatoric_std² + epistemic_std[i]²`.
#[derive(Debug, Clone, PartialEq)]
pub struct PredictiveMarginals {
    /// Predictive mean `φᵀμ` of each test row.
    pub mean: Vec<f64>,
    /// Observation-noise standard deviation `sqrt(1/β)`, shared by all rows.
    pub aleatoric_std: f64,
    /// Weight-uncertainty standard deviation `sqrt(max(φᵀΣφ, 0))` of each row.
    pub epistemic_std: Vec<f64>,
    /// Total predictive standard deviation `sqrt(1/β + max(φᵀΣφ, 0))` of each row.
    pub total_std: Vec<f64>,
}
/// Result of an ARD fit: weight posterior, learned hyperparameters, and evidence history.
pub struct FittedArd {
/// Gaussian posterior over the weights; the fitting functions store its `cov` row-major `d x d`.
pub posterior: Gaussian,
/// Per-feature ARD precisions reported by the fit (larger means a more strongly shrunk weight).
pub alpha: Vec<f64>,
/// Noise precision reported by the fit.
pub beta: f64,
/// Log evidence recorded at each iteration, oldest first.
pub log_evidences: Vec<f64>,
/// Number of training rows `n` used in the fit.
pub n_samples: usize,
}
impl FittedArd {
pub fn predict(
&self,
phi_test: &[f64],
n_test: usize,
n_features: usize,
) -> PredictiveMarginals {
let d = n_features;
let sigma_mat = Mat::<f64>::from_fn(d, d, |i, j| self.posterior.cov[i * d + j]);
let mu_col = Mat::<f64>::from_fn(d, 1, |i, _| self.posterior.mean[i]);
let aleatoric_var = 1.0 / self.beta;
let aleatoric_std = aleatoric_var.sqrt();
let mut mean = Vec::with_capacity(n_test);
let mut epistemic_std = Vec::with_capacity(n_test);
let mut total_std = Vec::with_capacity(n_test);
for i in 0..n_test {
let phi_row = Mat::<f64>::from_fn(1, d, |_, j| phi_test[i * d + j]);
let mut m_mat = Mat::<f64>::zeros(1, 1);
matmul::matmul(
m_mat.as_mut(),
Accum::Replace,
phi_row.as_ref(),
mu_col.as_ref(),
1.0_f64,
Par::Seq,
);
mean.push(m_mat[(0, 0)]);
let mut sigma_phi_t = Mat::<f64>::zeros(d, 1);
matmul::matmul(
sigma_phi_t.as_mut(),
Accum::Replace,
sigma_mat.as_ref(),
phi_row.as_ref().transpose(),
1.0_f64,
Par::Seq,
);
let mut ep_var_mat = Mat::<f64>::zeros(1, 1);
matmul::matmul(
ep_var_mat.as_mut(),
Accum::Replace,
phi_row.as_ref(),
sigma_phi_t.as_ref(),
1.0_f64,
Par::Seq,
);
let ep_var = ep_var_mat[(0, 0)].max(0.0);
epistemic_std.push(ep_var.sqrt());
total_std.push((aleatoric_var + ep_var).sqrt());
}
PredictiveMarginals {
mean,
aleatoric_std,
epistemic_std,
total_std,
}
}
pub fn predict_gaussian(
&self,
phi_test: &[f64],
n_test: usize,
n_features: usize,
) -> Result<Gaussian, BLRError> {
let d = n_features;
let m = n_test;
let phi_mat = Mat::<f64>::from_fn(m, d, |i, j| phi_test[i * d + j]);
let sigma_mat = Mat::<f64>::from_fn(d, d, |i, j| self.posterior.cov[i * d + j]);
let mu_col = Mat::<f64>::from_fn(d, 1, |i, _| self.posterior.mean[i]);
let mut pred_mean_mat = Mat::<f64>::zeros(m, 1);
matmul::matmul(
pred_mean_mat.as_mut(),
Accum::Replace,
phi_mat.as_ref(),
mu_col.as_ref(),
1.0_f64,
Par::Seq,
);
let mut tmp = Mat::<f64>::zeros(m, d);
matmul::matmul(
tmp.as_mut(),
Accum::Replace,
phi_mat.as_ref(),
sigma_mat.as_ref(),
1.0_f64,
Par::Seq,
);
let mut pred_cov = Mat::<f64>::zeros(m, m);
matmul::matmul(
pred_cov.as_mut(),
Accum::Replace,
tmp.as_ref(),
phi_mat.as_ref().transpose(),
1.0_f64,
Par::Seq,
);
let noise_var = 1.0 / self.beta;
for i in 0..m {
pred_cov[(i, i)] += noise_var + 1e-9; }
let pred_cov_ref = pred_cov.as_ref();
let pred_mean_vec: Vec<f64> = (0..m).map(|i| pred_mean_mat[(i, 0)]).collect();
let pred_cov_vec: Vec<f64> = (0..m)
.flat_map(|i| (0..m).map(move |j| pred_cov_ref[(i, j)]))
.collect();
Gaussian::new(pred_mean_vec, pred_cov_vec)
}
pub fn relevance(&self) -> Vec<f64> {
self.alpha.iter().map(|a| 1.0 / a).collect()
}
pub fn relevant_features(&self, threshold: Option<f64>) -> Vec<bool> {
let t = threshold.unwrap_or_else(|| {
let ln_mean = self.alpha.iter().map(|a| a.ln()).sum::<f64>() / self.alpha.len() as f64;
ln_mean.exp()
});
self.alpha.iter().map(|a| *a < t).collect()
}
pub fn noise_std(&self) -> f64 {
1.0 / self.beta.sqrt()
}
pub fn log_marginal_likelihood(&self) -> f64 {
*self.log_evidences.last().unwrap_or(&f64::NEG_INFINITY)
}
pub fn noise_precision(&self) -> f64 {
self.beta
}
pub fn posterior_covariance(&self) -> &[f64] {
&self.posterior.cov
}
pub fn sample_count(&self) -> usize {
self.n_samples
}
pub fn posterior_std(&self, phi_test: &[f64], n_test: usize, n_features: usize) -> Vec<f64> {
let d = n_features;
let sigma_cov = &self.posterior.cov;
let noise_var = 1.0 / self.beta.max(1e-10);
(0..n_test)
.map(|i| {
let phi_i = &phi_test[i * d..(i + 1) * d];
let mut sigma_phi = vec![0.0_f64; d];
for row in 0..d {
for col in 0..d {
sigma_phi[row] += sigma_cov[row * d + col] * phi_i[col];
}
}
let epistemic: f64 = phi_i.iter().zip(sigma_phi.iter()).map(|(a, b)| a * b).sum();
(noise_var + epistemic.max(0.0)).sqrt()
})
.collect()
}
pub fn posterior_std_grid(
&self,
input_range: (f64, f64),
resolution: usize,
feature_fn: &dyn Fn(f64) -> Vec<f64>,
) -> (Vec<f64>, Vec<f64>) {
let d_sq = self.posterior.cov.len();
let d = (d_sq as f64).sqrt() as usize;
let resolution = resolution.max(2);
let step = (input_range.1 - input_range.0) / (resolution - 1) as f64;
let grid: Vec<f64> = (0..resolution)
.map(|k| input_range.0 + k as f64 * step)
.collect();
let mut phi_grid = Vec::with_capacity(resolution * d);
for &x in &grid {
let feats = feature_fn(x);
let actual = feats.len().min(d);
phi_grid.extend_from_slice(&feats[..actual]);
if actual < d {
phi_grid.extend(std::iter::repeat(0.0).take(d - actual));
}
}
let stds = self.posterior_std(&phi_grid, resolution, d);
(grid, stds)
}
}
/// Log marginal likelihood `ln p(y | α, β)` of the ARD linear-Gaussian model.
///
/// With posterior precision `Σ⁻¹ = β ΦᵀΦ + diag(α)` and posterior mean `μ`:
///
/// ```text
/// ln p(y) = ½ [ Σ_j ln α_j + n ln β − ln|Σ⁻¹| − β ‖y − Φμ‖² − Σ_j α_j μ_j² ] − (n/2) ln 2π
/// ```
///
/// This is Bishop, *PRML* eq. 3.86, with one precision per feature. The `(d/2) ln 2π` from
/// the prior's normaliser cancels against the Gaussian integral over the weights, so no
/// `d`-dependent constant appears. The identity holds only when `mu`, `logdet_sigma_inv` and
/// `residual_sq` were computed from the same `alpha` and `beta`.
///
/// * `n` — number of observations.
/// * `d` — number of features; must equal `alpha.len()` and `mu.len()` (checked in debug builds).
/// * `logdet_sigma_inv` — `ln|Σ⁻¹|`.
/// * `residual_sq` — `‖y − Φμ‖²`.
fn log_evidence(
    n: usize,
    d: usize,
    alpha: &[f64],
    beta: f64,
    mu: &[f64],
    logdet_sigma_inv: f64,
    residual_sq: f64,
) -> f64 {
    debug_assert_eq!(alpha.len(), d, "alpha must have one entry per feature");
    debug_assert_eq!(mu.len(), d, "mu must have one entry per feature");
    let n = n as f64;
    let log_alpha_sum: f64 = alpha.iter().map(|a| a.ln()).sum();
    let weight_penalty: f64 = alpha.iter().zip(mu).map(|(a, m)| a * m * m).sum();
    0.5 * (log_alpha_sum + n * beta.ln() - logdet_sigma_inv - beta * residual_sq - weight_penalty)
        - 0.5 * n * (2.0 * PI).ln()
}
pub fn fit(
phi: &[f64],
y: &[f64],
n: usize,
d: usize,
config: &ArdConfig,
) -> Result<FittedArd, BLRError> {
if phi.len() != n * d {
return Err(BLRError::DimMismatch {
expected: n * d,
got: phi.len(),
});
}
if y.len() != n {
return Err(BLRError::DimMismatch {
expected: n,
got: y.len(),
});
}
let phi_mat = Mat::<f64>::from_fn(n, d, |i, j| phi[i * d + j]);
let y_mat = Mat::<f64>::from_fn(n, 1, |i, _| y[i]);
let mut phi_t_phi = Mat::<f64>::zeros(d, d);
matmul::matmul(
phi_t_phi.as_mut(),
Accum::Replace,
phi_mat.as_ref().transpose(),
phi_mat.as_ref(),
1.0_f64,
Par::Seq,
);
let mut phi_t_y = Mat::<f64>::zeros(d, 1);
matmul::matmul(
phi_t_y.as_mut(),
Accum::Replace,
phi_mat.as_ref().transpose(),
y_mat.as_ref(),
1.0_f64,
Par::Seq,
);
let mut alpha = vec![config.alpha_init; d];
let mut beta = config.beta_init;
let mut log_evidences: Vec<f64> = Vec::new();
let mut sigma_mat = Mat::<f64>::zeros(d, d);
let mut mu_vec = vec![0.0_f64; d];
for _iter in 0..config.max_iter {
let mut sigma_inv = Mat::<f64>::from_fn(d, d, |i, j| beta * phi_t_phi[(i, j)]);
for j in 0..d {
sigma_inv[(j, j)] += alpha[j];
}
let llt = sigma_inv
.llt(Side::Lower)
.map_err(|_| BLRError::SingularMatrix)?;
let eye = Mat::<f64>::identity(d, d);
sigma_mat = llt.solve(eye.as_ref());
let mut rhs = phi_t_y.clone();
for i in 0..d {
rhs[(i, 0)] *= beta;
}
let mu_mat = llt.solve(rhs.as_ref());
for i in 0..d {
mu_vec[i] = mu_mat[(i, 0)];
}
let logdet_sigma_inv = cholesky_logdet(&sigma_inv, d)?;
let mut phi_mu = Mat::<f64>::zeros(n, 1);
let mu_mat_ref = Mat::<f64>::from_fn(d, 1, |i, _| mu_vec[i]);
matmul::matmul(
phi_mu.as_mut(),
Accum::Replace,
phi_mat.as_ref(),
mu_mat_ref.as_ref(),
1.0_f64,
Par::Seq,
);
let residual_sq: f64 = (0..n)
.map(|i| {
let r = y[i] - phi_mu[(i, 0)];
r * r
})
.sum();
let gamma: Vec<f64> = (0..d).map(|j| 1.0 - alpha[j] * sigma_mat[(j, j)]).collect();
for j in 0..d {
alpha[j] = (gamma[j] / (mu_vec[j] * mu_vec[j] + 1e-10)).max(1e-8);
}
if config.update_beta {
let gamma_sum: f64 = gamma.iter().sum();
beta = ((n as f64 - gamma_sum) / (residual_sq + 1e-10)).max(1e-8);
}
let lml = log_evidence(n, d, &alpha, beta, &mu_vec, logdet_sigma_inv, residual_sq);
log_evidences.push(lml);
let n_ev = log_evidences.len();
let delta = if n_ev >= 4 {
let mean_curr = 0.5 * (log_evidences[n_ev - 1] + log_evidences[n_ev - 2]);
let mean_prev = 0.5 * (log_evidences[n_ev - 3] + log_evidences[n_ev - 4]);
(mean_curr - mean_prev).abs()
} else if n_ev >= 2 {
(log_evidences[n_ev - 1] - log_evidences[n_ev - 2]).abs()
} else {
f64::INFINITY
};
if delta < config.tol {
break;
}
}
let mu_final: Vec<f64> = mu_vec.clone();
let cov_final: Vec<f64> = {
let sigma_ref = sigma_mat.as_ref();
(0..d)
.flat_map(|i| (0..d).map(move |j| sigma_ref[(i, j)]))
.collect()
};
let posterior = Gaussian::new(mu_final, cov_final)?;
Ok(FittedArd {
posterior,
alpha,
beta,
log_evidences,
n_samples: n,
})
}
pub fn fit_with_prior(
phi: &[f64],
y: &[f64],
n: usize,
d: usize,
config: &ArdConfig,
prior: Option<&BLRPrior>,
) -> Result<FittedArd, BLRError> {
if phi.len() != n * d {
return Err(BLRError::DimMismatch {
expected: n * d,
got: phi.len(),
});
}
if y.len() != n {
return Err(BLRError::DimMismatch {
expected: n,
got: y.len(),
});
}
if let Some(p) = prior {
p.validate()?;
if p.mean.len() != d {
return Err(BLRError::DimMismatch {
expected: d,
got: p.mean.len(),
});
}
}
let phi_mat = Mat::<f64>::from_fn(n, d, |i, j| phi[i * d + j]);
let y_mat = Mat::<f64>::from_fn(n, 1, |i, _| y[i]);
let mut phi_t_phi = Mat::<f64>::zeros(d, d);
matmul::matmul(
phi_t_phi.as_mut(),
Accum::Replace,
phi_mat.as_ref().transpose(),
phi_mat.as_ref(),
1.0_f64,
Par::Seq,
);
let mut phi_t_y = Mat::<f64>::zeros(d, 1);
matmul::matmul(
phi_t_y.as_mut(),
Accum::Replace,
phi_mat.as_ref().transpose(),
y_mat.as_ref(),
1.0_f64,
Par::Seq,
);
let mut alpha: Vec<f64> = prior
.map(|p| p.alphas.clone())
.unwrap_or_else(|| vec![config.alpha_init; d]);
let mut beta = config.beta_init;
let mut log_evidences: Vec<f64> = Vec::new();
let mut sigma_mat = Mat::<f64>::zeros(d, d);
let mut mu_vec: Vec<f64> = prior
.map(|p| p.mean.clone())
.unwrap_or_else(|| vec![0.0f64; d]);
for _iter in 0..config.max_iter {
let mut sigma_inv = Mat::<f64>::from_fn(d, d, |i, j| beta * phi_t_phi[(i, j)]);
for j in 0..d {
sigma_inv[(j, j)] += alpha[j];
}
let llt = sigma_inv
.llt(Side::Lower)
.map_err(|_| BLRError::SingularMatrix)?;
let eye = Mat::<f64>::identity(d, d);
sigma_mat = llt.solve(eye.as_ref());
let mut rhs = phi_t_y.clone();
for i in 0..d {
rhs[(i, 0)] *= beta;
}
let mu_mat = llt.solve(rhs.as_ref());
for i in 0..d {
mu_vec[i] = mu_mat[(i, 0)];
}
let logdet_sigma_inv = cholesky_logdet(&sigma_inv, d)?;
let mut phi_mu = Mat::<f64>::zeros(n, 1);
let mu_mat_ref = Mat::<f64>::from_fn(d, 1, |i, _| mu_vec[i]);
matmul::matmul(
phi_mu.as_mut(),
Accum::Replace,
phi_mat.as_ref(),
mu_mat_ref.as_ref(),
1.0_f64,
Par::Seq,
);
let residual_sq: f64 = (0..n)
.map(|i| {
let r = y[i] - phi_mu[(i, 0)];
r * r
})
.sum();
let gamma: Vec<f64> = (0..d).map(|j| 1.0 - alpha[j] * sigma_mat[(j, j)]).collect();
for j in 0..d {
alpha[j] = (gamma[j] / (mu_vec[j] * mu_vec[j] + 1e-10)).max(1e-8);
}
if config.update_beta {
let gamma_sum: f64 = gamma.iter().sum();
beta = ((n as f64 - gamma_sum) / (residual_sq + 1e-10)).max(1e-8);
}
let lml = log_evidence(n, d, &alpha, beta, &mu_vec, logdet_sigma_inv, residual_sq);
log_evidences.push(lml);
let n_ev = log_evidences.len();
let delta = if n_ev >= 4 {
let mean_curr = 0.5 * (log_evidences[n_ev - 1] + log_evidences[n_ev - 2]);
let mean_prev = 0.5 * (log_evidences[n_ev - 3] + log_evidences[n_ev - 4]);
(mean_curr - mean_prev).abs()
} else if n_ev >= 2 {
(log_evidences[n_ev - 1] - log_evidences[n_ev - 2]).abs()
} else {
f64::INFINITY
};
if delta < config.tol {
break;
}
}
let mu_final: Vec<f64> = mu_vec.clone();
let cov_final: Vec<f64> = {
let sigma_ref = sigma_mat.as_ref();
(0..d)
.flat_map(|i| (0..d).map(move |j| sigma_ref[(i, j)]))
.collect()
};
let posterior = Gaussian::new(mu_final, cov_final)?;
Ok(FittedArd {
posterior,
alpha,
beta,
log_evidences,
n_samples: n,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ard_config_defaults() {
        let cfg = ArdConfig::default();
        assert_eq!(cfg.alpha_init, 1.0);
        assert_eq!(cfg.beta_init, 1.0);
        assert_eq!(cfg.max_iter, 100);
        assert_eq!(cfg.tol, 1e-5);
        assert!(cfg.update_beta);
    }

    #[test]
    fn test_log_evidence_helper() {
        let lml = log_evidence(10, 3, &[1.0; 3], 1.0, &[0.0; 3], 5.0, 2.0);
        assert!(lml.is_finite(), "log_evidence = {lml}");
    }

    #[test]
    fn test_blr_prior_valid() {
        let d = 3;
        let prior = BLRPrior {
            mean: vec![0.0; d],
            cov: vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
            alphas: vec![1.0; d],
        };
        assert!(prior.validate().is_ok());
    }

    #[test]
    fn test_blr_prior_invalid_dimensions() {
        // A 2x2 covariance cannot describe a 3-dimensional prior.
        let prior = BLRPrior {
            mean: vec![0.0; 3],
            cov: vec![1.0, 0.0, 0.0, 1.0],
            alphas: vec![1.0; 3],
        };
        assert!(matches!(
            prior.validate(),
            Err(BLRError::DimMismatch { expected: 9, got: 4 })
        ));
    }

    #[test]
    fn test_blr_prior_zero_dimension_rejected() {
        let prior = BLRPrior {
            mean: vec![],
            cov: vec![],
            alphas: vec![],
        };
        assert!(matches!(
            prior.validate(),
            Err(BLRError::DimMismatch { expected: 1, got: 0 })
        ));
    }

    #[test]
    fn test_blr_prior_not_psd() {
        let d = 2;
        let prior = BLRPrior {
            mean: vec![0.0; d],
            cov: vec![-1.0, 0.0, 0.0, -1.0],
            alphas: vec![1.0; d],
        };
        assert!(matches!(prior.validate(), Err(BLRError::SingularMatrix)));
    }

    #[test]
    fn test_fit_rejects_mismatched_lengths() {
        let config = ArdConfig::default();
        // phi must hold n * d = 6 values.
        assert!(matches!(
            fit(&[1.0; 5], &[1.0, 2.0], 2, 3, &config),
            Err(BLRError::DimMismatch { expected: 6, got: 5 })
        ));
        // y must hold n = 2 values.
        assert!(matches!(
            fit(&[1.0; 6], &[1.0, 2.0, 3.0], 2, 3, &config),
            Err(BLRError::DimMismatch { expected: 2, got: 3 })
        ));
    }

    #[test]
    fn test_fit_with_prior_rejects_wrong_prior_dimension() {
        let prior = BLRPrior {
            mean: vec![0.0; 2],
            cov: vec![1.0, 0.0, 0.0, 1.0],
            alphas: vec![1.0; 2],
        };
        let result = fit_with_prior(
            &[1.0; 6],
            &[1.0, 2.0],
            2,
            3,
            &ArdConfig::default(),
            Some(&prior),
        );
        assert!(matches!(
            result,
            Err(BLRError::DimMismatch { expected: 3, got: 2 })
        ));
    }

    #[test]
    fn test_fit_with_prior_none_equals_fit() {
        let phi: Vec<f64> = vec![1.0, 0.5, 0.25, 2.0, 1.0, 0.5, 0.5, 0.25, 0.125];
        let y: Vec<f64> = vec![1.0, 2.0, 0.5];
        let config = ArdConfig::default();
        let r1 = fit(&phi, &y, 3, 3, &config).unwrap();
        let r2 = fit_with_prior(&phi, &y, 3, 3, &config, None).unwrap();
        assert_eq!(r1.alpha.len(), r2.alpha.len());
        for (a1, a2) in r1.alpha.iter().zip(r2.alpha.iter()) {
            assert!((a1 - a2).abs() < 1e-10, "alpha mismatch: {a1} vs {a2}");
        }
        assert!((r1.beta - r2.beta).abs() < 1e-10);
    }

    #[test]
    fn test_fit_with_prior_some_compiles_and_runs() {
        let d = 3;
        let phi: Vec<f64> = vec![
            1.0, 0.5, 0.25, 2.0, 1.0, 0.5, 0.5, 0.25, 0.125, 1.5, 0.75, 0.3, 0.8, 0.4, 0.2,
        ];
        let y: Vec<f64> = vec![1.0, 2.0, 0.5, 1.5, 0.8];
        let config = ArdConfig::default();
        let prior = BLRPrior {
            mean: vec![0.5; d],
            cov: vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
            alphas: vec![0.5; d],
        };
        let result = fit_with_prior(&phi, &y, 5, d, &config, Some(&prior));
        assert!(
            result.is_ok(),
            "fit_with_prior should succeed: {:?}",
            result.err()
        );
        let fitted = result.unwrap();
        assert!(fitted.noise_std() > 0.0);
        assert_eq!(fitted.alpha.len(), d);
    }

    /// Warm-starting from a converged fit must not need more iterations than a cold start.
    #[test]
    fn test_fit_with_prior_warm_start_not_slower() {
        let d = 3;
        let n = 5;
        let phi: Vec<f64> = vec![
            1.0, 0.5, 0.25, 2.0, 1.0, 0.5, 0.5, 0.25, 0.125, 1.5, 0.75, 0.3, 0.8, 0.4, 0.2,
        ];
        let y: Vec<f64> = vec![1.0, 2.0, 0.5, 1.5, 0.8];
        let config = ArdConfig {
            max_iter: 200,
            tol: 1e-9,
            ..ArdConfig::default()
        };
        let baseline = fit_with_prior(&phi, &y, n, d, &config, None).unwrap();
        let prior = BLRPrior {
            mean: baseline.posterior.mean.clone(),
            cov: baseline.posterior.cov.clone(),
            alphas: baseline.alpha.clone(),
        };
        let informed = fit_with_prior(&phi, &y, n, d, &config, Some(&prior)).unwrap();
        assert!(informed.noise_std() > 0.0);
        assert!(
            informed.log_evidences.len() <= baseline.log_evidences.len(),
            "informed iterations {} should be <= baseline iterations {}",
            informed.log_evidences.len(),
            baseline.log_evidences.len()
        );
    }

    #[test]
    fn test_predict_consistent_with_posterior_std() {
        // Orthogonal, well-conditioned design: columns [1, 0, 1, 1] and [0, 1, 1, -1].
        let phi: Vec<f64> = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, -1.0];
        let y: Vec<f64> = vec![1.1, 1.9, 3.05, -0.95];
        let fitted = fit(&phi, &y, 4, 2, &ArdConfig::default()).unwrap();

        let phi_test: Vec<f64> = vec![1.0, 0.0, 0.0, 1.0, 0.5, -0.5];
        let pred = fitted.predict(&phi_test, 3, 2);
        assert_eq!(pred.mean.len(), 3);
        assert_eq!(pred.epistemic_std.len(), 3);
        assert_eq!(pred.total_std.len(), 3);

        // Least-squares weights are (16/15, 59/30); ARD shrinkage is tiny for this data.
        assert!((pred.mean[0] - 16.0 / 15.0).abs() < 0.05, "mean[0] = {}", pred.mean[0]);
        assert!((pred.mean[1] - 59.0 / 30.0).abs() < 0.05, "mean[1] = {}", pred.mean[1]);

        let stds = fitted.posterior_std(&phi_test, 3, 2);
        assert_eq!(stds.len(), 3);
        for ((total, epistemic), std) in pred.total_std.iter().zip(&pred.epistemic_std).zip(&stds) {
            assert!(total >= epistemic);
            assert!(*total >= pred.aleatoric_std);
            assert!((total - std).abs() < 1e-9, "{total} vs {std}");
        }
    }
}