use scirs2_core::ndarray::{Array1, Array2};
use crate::error::{StatsError, StatsResult};
use super::types::{BNNConfig, BNNPosterior, CovarianceType, PredictiveDistribution};
/// Laplace (Gaussian) approximation to a Bayesian neural network posterior.
///
/// Holds the fitted Gaussian posterior N(w_MAP, H^{-1}) — where H is the
/// Gauss-Newton Hessian J^T J plus the isotropic prior precision — together
/// with the configuration used to produce it. Construct via [`LaplaceApproximation::fit`].
#[derive(Debug, Clone)]
pub struct LaplaceApproximation {
    // Fitted Gaussian posterior: MAP mean, covariance, and the Laplace
    // estimate of the log marginal likelihood.
    posterior: BNNPosterior,
    // Configuration the fit was performed with (e.g. prior precision).
    config: BNNConfig,
}
impl LaplaceApproximation {
    /// Fits a Laplace approximation around the given MAP weights.
    ///
    /// The Hessian of the negative log posterior is approximated by the
    /// Gauss-Newton matrix `J^T J` with the isotropic prior precision added
    /// to its diagonal; the posterior covariance is its inverse. The Laplace
    /// estimate of the log marginal likelihood is also evaluated:
    /// `log p(D|w) + log p(w) + (d/2) ln(2*pi) - (1/2) ln|H|`
    /// (unit noise variance is assumed; normalization constants of the
    /// likelihood and prior are omitted).
    ///
    /// # Errors
    /// Returns a dimension-mismatch error when `weights`/`residuals` do not
    /// agree with the Jacobian's shape, an invalid-argument error when there
    /// are zero parameters, and a computation error when the regularized
    /// Hessian is not positive definite.
    pub fn fit(
        weights: &Array1<f64>,
        jacobian: &Array2<f64>,
        residuals: &Array1<f64>,
        config: &BNNConfig,
    ) -> StatsResult<Self> {
        let (n, d) = (jacobian.nrows(), jacobian.ncols());
        if weights.len() != d {
            return Err(StatsError::dimension_mismatch(format!(
                "weights length {} != jacobian columns {}",
                weights.len(),
                d
            )));
        }
        if residuals.len() != n {
            return Err(StatsError::dimension_mismatch(format!(
                "residuals length {} != jacobian rows {}",
                residuals.len(),
                n
            )));
        }
        if d == 0 {
            return Err(StatsError::invalid_argument(
                "Number of parameters must be > 0",
            ));
        }

        // Gauss-Newton Hessian regularized by the prior precision on the diagonal.
        let mut hessian = jacobian.t().dot(jacobian);
        for diag in hessian.diag_mut().iter_mut() {
            *diag += config.prior_precision;
        }
        let cov = cholesky_inverse(&hessian)?;

        // Laplace evidence terms (left-to-right accumulation, unit noise variance).
        let sse = residuals.iter().fold(0.0, |acc, r| acc + r * r);
        let log_likelihood = -0.5 * sse;
        let weight_sq = weights.iter().fold(0.0, |acc, w| acc + w * w);
        let log_prior = -0.5 * config.prior_precision * weight_sq;
        let log_det_h = cholesky_log_det(&hessian)?;
        let log_marginal = log_likelihood
            + log_prior
            + 0.5 * (d as f64) * (2.0 * std::f64::consts::PI).ln()
            - 0.5 * log_det_h;

        Ok(Self {
            posterior: BNNPosterior {
                mean: weights.clone(),
                covariance_type: CovarianceType::Full(cov),
                log_marginal_likelihood: log_marginal,
            },
            config: config.clone(),
        })
    }

    /// Predictive distribution under the linearized (generalized Gauss-Newton)
    /// model: the variance of test point `i` is `J_i Sigma J_i^T`, the diagonal
    /// of `J Sigma J^T`. No observation-noise term is added.
    ///
    /// # Errors
    /// Returns a dimension-mismatch error when `mean_prediction` or the stored
    /// covariance does not agree with `jacobian_test`'s shape, and a
    /// computation error when the posterior does not carry a full covariance.
    pub fn predict(
        &self,
        jacobian_test: &Array2<f64>,
        mean_prediction: &Array1<f64>,
    ) -> StatsResult<PredictiveDistribution> {
        let n_test = jacobian_test.nrows();
        let d = jacobian_test.ncols();
        if mean_prediction.len() != n_test {
            return Err(StatsError::dimension_mismatch(format!(
                "mean_prediction length {} != jacobian_test rows {}",
                mean_prediction.len(),
                n_test
            )));
        }
        let cov = match &self.posterior.covariance_type {
            CovarianceType::Full(c) => c,
            _ => {
                return Err(StatsError::computation(
                    "Laplace predict requires Full covariance",
                ))
            }
        };
        if cov.nrows() != d || cov.ncols() != d {
            return Err(StatsError::dimension_mismatch(format!(
                "Covariance shape [{}, {}] incompatible with Jacobian columns {}",
                cov.nrows(),
                cov.ncols(),
                d
            )));
        }

        // diag(J Sigma J^T): row-wise dot product of (J Sigma) with J.
        let j_sigma = jacobian_test.dot(cov);
        let variance: Array1<f64> = j_sigma
            .outer_iter()
            .zip(jacobian_test.outer_iter())
            .map(|(s_row, j_row)| {
                s_row
                    .iter()
                    .zip(j_row.iter())
                    .map(|(a, b)| a * b)
                    .sum::<f64>()
            })
            .collect();

        Ok(PredictiveDistribution {
            mean: mean_prediction.clone(),
            variance,
            samples: None,
        })
    }

    /// Laplace estimate of the log marginal likelihood (model evidence).
    pub fn log_marginal_likelihood(&self) -> f64 {
        self.posterior.log_marginal_likelihood
    }

    /// Fitted Gaussian posterior.
    pub fn posterior(&self) -> &BNNPosterior {
        &self.posterior
    }

    /// Configuration used for the fit.
    pub fn config(&self) -> &BNNConfig {
        &self.config
    }
}
/// Computes the two Kronecker factors of the KFAC Fisher approximation:
/// `A = activations^T activations / n` and `B = gradients^T gradients / n`,
/// i.e. the sample second-moment matrices of the layer inputs and the
/// back-propagated output gradients.
///
/// # Errors
/// Returns a dimension-mismatch error when the two matrices disagree on the
/// number of rows (samples), and an invalid-argument error when there are no
/// samples at all.
pub fn kfac_factors(
    activations: &Array2<f64>,
    gradients: &Array2<f64>,
) -> StatsResult<(Array2<f64>, Array2<f64>)> {
    let rows_act = activations.nrows();
    let rows_grad = gradients.nrows();
    if rows_act != rows_grad {
        return Err(StatsError::dimension_mismatch(format!(
            "activations rows {} != gradients rows {}",
            rows_act, rows_grad
        )));
    }
    if rows_act == 0 {
        return Err(StatsError::invalid_argument(
            "Need at least 1 sample for KFAC",
        ));
    }
    let sample_count = rows_act as f64;
    Ok((
        activations.t().dot(activations) / sample_count,
        gradients.t().dot(gradients) / sample_count,
    ))
}
/// Computes the lower-triangular Cholesky factor `L` of a symmetric
/// positive-definite matrix `A`, so that `A = L L^T`
/// (Cholesky-Banachiewicz, column by column).
///
/// # Errors
/// Returns a dimension-mismatch error when `a` is not square and a
/// computation error (reporting the offending pivot) when `a` is not
/// positive definite.
fn cholesky_decompose(a: &Array2<f64>) -> StatsResult<Array2<f64>> {
    let dim = a.nrows();
    if dim != a.ncols() {
        return Err(StatsError::dimension_mismatch("Matrix must be square"));
    }
    let mut lower = Array2::<f64>::zeros((dim, dim));
    for col in 0..dim {
        // Diagonal entry: sqrt of a[col][col] minus the squared entries
        // already placed in row `col`.
        let row_sq: f64 = (0..col).map(|k| lower[[col, k]] * lower[[col, k]]).sum();
        let pivot = a[[col, col]] - row_sq;
        if pivot <= 0.0 {
            return Err(StatsError::computation(format!(
                "Matrix is not positive definite (pivot {} at index {})",
                pivot, col
            )));
        }
        lower[[col, col]] = pivot.sqrt();
        // Remaining entries of column `col`, below the diagonal.
        for row in (col + 1)..dim {
            let partial: f64 = (0..col).map(|k| lower[[row, k]] * lower[[col, k]]).sum();
            lower[[row, col]] = (a[[row, col]] - partial) / lower[[col, col]];
        }
    }
    Ok(lower)
}
/// Inverts a symmetric positive-definite matrix via its Cholesky factor:
/// with `A = L L^T`, returns `A^{-1} = (L^{-1})^T (L^{-1})`.
///
/// # Errors
/// Propagates the errors of [`cholesky_decompose`] (non-square or not
/// positive definite input).
fn cholesky_inverse(a: &Array2<f64>) -> StatsResult<Array2<f64>> {
    let l = cholesky_decompose(a)?;
    let n = l.nrows();
    // Forward substitution to invert the lower-triangular factor:
    //   (L^{-1})[i][i] = 1 / L[i][i]
    //   (L^{-1})[i][j] = -(1/L[i][i]) * sum_{k=j}^{i-1} L[i][k] * (L^{-1})[k][j]
    let mut l_inv = Array2::<f64>::zeros((n, n));
    for i in 0..n {
        l_inv[[i, i]] = 1.0 / l[[i, i]];
        for j in 0..i {
            // BUG FIX: the sum must start at k = j. The previous version
            // summed k in (j+1)..=i, dropping the L[i][j] * L^{-1}[j][j]
            // term (and adding a term that is always zero), which produced
            // a wrong inverse for every non-diagonal SPD matrix.
            let s: f64 = (j..i).map(|k| l[[i, k]] * l_inv[[k, j]]).sum();
            l_inv[[i, j]] = -s / l[[i, i]];
        }
    }
    // A^{-1} = L^{-T} L^{-1}.
    Ok(l_inv.t().dot(&l_inv))
}
/// Log-determinant of a symmetric positive-definite matrix via Cholesky:
/// `ln|A| = 2 * sum_i ln(L[i][i])` since `|A| = |L|^2`.
///
/// # Errors
/// Propagates the errors of [`cholesky_decompose`].
fn cholesky_log_det(a: &Array2<f64>) -> StatsResult<f64> {
    let l = cholesky_decompose(a)?;
    let diag_log_sum: f64 = (0..l.nrows()).map(|i| l[[i, i]].ln()).sum();
    Ok(2.0 * diag_log_sum)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};

    /// The Cholesky factor of the identity is the identity itself.
    #[test]
    fn test_cholesky_identity() {
        let identity = Array2::from_diag(&Array1::from_vec(vec![1.0, 1.0, 1.0]));
        let factor = cholesky_decompose(&identity).expect("Cholesky of identity should succeed");
        for idx in 0..3 {
            assert!((factor[[idx, idx]] - 1.0).abs() < 1e-12);
        }
    }

    /// The inverse of a diagonal matrix is the element-wise reciprocal.
    #[test]
    fn test_cholesky_inverse_identity() {
        let diagonal = Array2::from_diag(&Array1::from_vec(vec![2.0, 3.0, 5.0]));
        let inverse = cholesky_inverse(&diagonal).expect("inverse of diagonal should succeed");
        assert!((inverse[[0, 0]] - 0.5).abs() < 1e-12);
        assert!((inverse[[1, 1]] - 1.0 / 3.0).abs() < 1e-12);
        assert!((inverse[[2, 2]] - 0.2).abs() < 1e-12);
    }

    /// With J = [1, 2, 3]^T and unit prior precision the posterior variance
    /// is 1 / (1 + 4 + 9 + 1) = 1/15.
    #[test]
    fn test_laplace_quadratic_loss() {
        let design = array![[1.0], [2.0], [3.0]];
        let map_weights = array![2.0];
        let residuals = array![0.1, -0.05, 0.02];
        let config = BNNConfig {
            prior_precision: 1.0,
            ..BNNConfig::default()
        };
        let approx = LaplaceApproximation::fit(&map_weights, &design, &residuals, &config)
            .expect("Laplace fit should succeed");
        match &approx.posterior().covariance_type {
            CovarianceType::Full(cov) => {
                let expected_var = 1.0 / 15.0;
                assert!(
                    (cov[[0, 0]] - expected_var).abs() < 1e-10,
                    "Expected variance {}, got {}",
                    expected_var,
                    cov[[0, 0]]
                );
            }
            _ => panic!("Expected Full covariance"),
        }
    }

    /// Predictive variance J Sigma J^T grows with the magnitude of the
    /// test Jacobian (farther from the data in feature space).
    #[test]
    fn test_laplace_predict_uncertainty_grows() {
        let design = array![[1.0], [2.0]];
        let map_weights = array![1.0];
        let residuals = array![0.0, 0.0];
        let config = BNNConfig::default();
        let approx =
            LaplaceApproximation::fit(&map_weights, &design, &residuals, &config).expect("fit");
        let jac_near = array![[1.5]];
        let jac_far = array![[10.0]];
        let pred_near = approx.predict(&jac_near, &array![1.5]).expect("predict near");
        let pred_far = approx.predict(&jac_far, &array![10.0]).expect("predict far");
        assert!(
            pred_far.variance[0] > pred_near.variance[0],
            "Uncertainty should grow: near={}, far={}",
            pred_near.variance[0],
            pred_far.variance[0]
        );
    }

    /// Two weights against a one-column Jacobian must be rejected.
    #[test]
    fn test_laplace_dimension_mismatch() {
        let weights = array![1.0, 2.0];
        let jac = array![[1.0]];
        let residuals = array![0.1];
        let config = BNNConfig::default();
        assert!(LaplaceApproximation::fit(&weights, &jac, &residuals, &config).is_err());
    }

    /// KFAC factors have the expected shapes, are symmetric, and have
    /// non-negative diagonals (they are Gram matrices).
    #[test]
    fn test_kfac_factors_symmetric_psd() {
        let activations = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let gradients = array![[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]];
        let (a_factor, b_factor) =
            kfac_factors(&activations, &gradients).expect("KFAC should succeed");
        assert_eq!(a_factor.nrows(), 2);
        assert_eq!(a_factor.ncols(), 2);
        assert_eq!(b_factor.nrows(), 3);
        assert_eq!(b_factor.ncols(), 3);
        assert!((a_factor[[0, 1]] - a_factor[[1, 0]]).abs() < 1e-12);
        for idx in 0..a_factor.nrows() {
            assert!(a_factor[[idx, idx]] >= 0.0);
        }
        for idx in 0..b_factor.nrows() {
            assert!(b_factor[[idx, idx]] >= 0.0);
        }
    }

    /// Mismatched sample counts between activations and gradients error out.
    #[test]
    fn test_kfac_row_mismatch() {
        let activations = array![[1.0], [2.0]];
        let gradients = array![[1.0], [2.0], [3.0]];
        assert!(kfac_factors(&activations, &gradients).is_err());
    }

    /// The Laplace evidence estimate is a finite number on a simple fit.
    #[test]
    fn test_log_marginal_likelihood_finite() {
        let jac = array![[1.0, 0.0], [0.0, 1.0]];
        let weights = array![1.0, 1.0];
        let residuals = array![0.1, -0.1];
        let config = BNNConfig::default();
        let approx = LaplaceApproximation::fit(&weights, &jac, &residuals, &config).expect("fit");
        let lml = approx.log_marginal_likelihood();
        assert!(lml.is_finite(), "log marginal likelihood should be finite");
    }
}