use crate::error::{StatsError, StatsResult};
use super::types::{BnnApproxResult, HessianMethod, LaplaceConfig};
/// Diagonal Generalized Gauss-Newton approximation: entry j is the sum over
/// samples of the squared per-sample gradient component g_ij^2.
///
/// # Errors
/// Fails when `grad_matrix` is empty, when its rows are zero-length, or when
/// the rows have inconsistent lengths.
pub fn diagonal_ggn(grad_matrix: &[Vec<f64>]) -> StatsResult<Vec<f64>> {
    let first_row = grad_matrix.first().ok_or_else(|| {
        StatsError::invalid_argument("diagonal_ggn: grad_matrix must not be empty")
    })?;
    let n_params = first_row.len();
    if n_params == 0 {
        return Err(StatsError::invalid_argument(
            "diagonal_ggn: each gradient vector must have length > 0",
        ));
    }
    // Validate rectangularity before accumulating anything.
    for (i, row) in grad_matrix.iter().enumerate() {
        if row.len() != n_params {
            return Err(StatsError::dimension_mismatch(format!(
                "grad_matrix row {} has length {} ≠ n_params {}",
                i,
                row.len(),
                n_params
            )));
        }
    }
    // Column-wise accumulation of squared gradients via a fold over samples.
    let ggn = grad_matrix
        .iter()
        .fold(vec![0.0f64; n_params], |mut acc, row| {
            acc.iter_mut().zip(row).for_each(|(a, &g)| *a += g * g);
            acc
        });
    Ok(ggn)
}
/// Per-sample gradients of `loss_fn` at `weights` via central finite
/// differences.
///
/// `loss_fn` maps a weight vector to a vector of per-sample losses and must
/// return the same number of losses on every call. The result is an
/// `n_data × n_params` matrix where entry (i, j) ≈
/// (L_i(w + h·e_j) − L_i(w − h·e_j)) / (2h).
///
/// # Errors
/// - `weights` is empty
/// - `fd_step` is not a finite positive number
/// - `loss_fn` returns an empty vector, or vectors whose length changes
///   between calls
pub fn fd_per_sample_gradients(
    weights: &[f64],
    loss_fn: &dyn Fn(&[f64]) -> Vec<f64>,
    fd_step: f64,
) -> StatsResult<Vec<Vec<f64>>> {
    if weights.is_empty() {
        return Err(StatsError::invalid_argument(
            "fd_per_sample_gradients: weights must not be empty",
        ));
    }
    // `!(fd_step > 0.0)` rejects NaN as well (NaN fails every comparison,
    // so the old `fd_step <= 0.0` check let NaN slip through); the
    // `is_finite` check additionally rejects +inf.
    if !(fd_step > 0.0) || !fd_step.is_finite() {
        return Err(StatsError::invalid_argument(
            "fd_per_sample_gradients: fd_step must be finite and > 0",
        ));
    }
    let losses_at_w = loss_fn(weights);
    let n_data = losses_at_w.len();
    if n_data == 0 {
        return Err(StatsError::invalid_argument(
            "fd_per_sample_gradients: loss_fn returned empty vector",
        ));
    }
    let n_params = weights.len();
    let mut grad_matrix = vec![vec![0.0f64; n_params]; n_data];
    // Perturb one coordinate at a time and restore it afterwards, so each
    // evaluation differs from `weights` in exactly one entry.
    let mut w_fwd = weights.to_vec();
    let mut w_bwd = weights.to_vec();
    for j in 0..n_params {
        w_fwd[j] = weights[j] + fd_step;
        w_bwd[j] = weights[j] - fd_step;
        let l_fwd = loss_fn(&w_fwd);
        let l_bwd = loss_fn(&w_bwd);
        // Guard against a loss_fn whose output length varies between calls;
        // blind indexing below would panic instead of returning an error.
        if l_fwd.len() != n_data || l_bwd.len() != n_data {
            return Err(StatsError::dimension_mismatch(format!(
                "fd_per_sample_gradients: loss_fn returned {} / {} losses, expected {}",
                l_fwd.len(),
                l_bwd.len(),
                n_data
            )));
        }
        for i in 0..n_data {
            grad_matrix[i][j] = (l_fwd[i] - l_bwd[i]) / (2.0 * fd_step);
        }
        w_fwd[j] = weights[j];
        w_bwd[j] = weights[j];
    }
    Ok(grad_matrix)
}
/// Converts a diagonal GGN/Hessian approximation into per-parameter posterior
/// variances: var_j = 1 / (h_j + damping), with a floor of
/// `1 / max(damping, 1e-12)` whenever the damped curvature is non-positive.
///
/// # Errors
/// Fails when `ggn_diag` is empty.
pub fn posterior_variance_from_ggn(ggn_diag: &[f64], damping: f64) -> StatsResult<Vec<f64>> {
    if ggn_diag.is_empty() {
        return Err(StatsError::invalid_argument(
            "posterior_variance_from_ggn: ggn_diag must not be empty",
        ));
    }
    let mut var = Vec::with_capacity(ggn_diag.len());
    for &h in ggn_diag {
        let denom = h + damping;
        // Keep the `denom <= 0.0` comparison (not its negation) so NaN
        // curvatures propagate exactly as before.
        let v = if denom <= 0.0 {
            1.0 / damping.max(1e-12)
        } else {
            1.0 / denom
        };
        var.push(v);
    }
    Ok(var)
}
/// Predictive mean of a linear model: the dot product x · weights.
///
/// # Errors
/// Fails when `x` and `weights` have different lengths.
pub fn predict_mean_linear(x: &[f64], weights: &[f64]) -> StatsResult<f64> {
    if x.len() != weights.len() {
        return Err(StatsError::dimension_mismatch(format!(
            "predict_mean: x.len()={} ≠ weights.len()={}",
            x.len(),
            weights.len()
        )));
    }
    let mut dot = 0.0f64;
    for (xi, wi) in x.iter().zip(weights.iter()) {
        dot += xi * wi;
    }
    Ok(dot)
}
/// Predictive variance of a linear model under a diagonal weight posterior:
/// sum_j x_j^2 · var_j.
///
/// # Errors
/// Fails when `x` and `posterior_var` have different lengths.
pub fn predict_variance_linear(x: &[f64], posterior_var: &[f64]) -> StatsResult<f64> {
    if x.len() != posterior_var.len() {
        return Err(StatsError::dimension_mismatch(format!(
            "predict_variance: x.len()={} ≠ posterior_var.len()={}",
            x.len(),
            posterior_var.len()
        )));
    }
    let total = x
        .iter()
        .zip(posterior_var)
        .fold(0.0f64, |acc, (&xi, &vi)| acc + xi * xi * vi);
    Ok(total)
}
/// Fits a diagonal Laplace approximation around the given MAP weights:
/// estimates per-sample gradients by finite differences, builds a diagonal
/// GGN, and inverts the damped curvature into posterior variances.
///
/// The chosen `config.hessian_method` is recorded in the result's `method`
/// label; all methods currently share the same diagonal-GGN pipeline.
///
/// # Errors
/// Fails when `map_weights` is empty, or when the gradient / curvature
/// helpers reject their inputs (see their docs).
pub fn fit_laplace(
    map_weights: &[f64],
    loss_fn: &dyn Fn(&[f64]) -> Vec<f64>,
    config: &LaplaceConfig,
) -> StatsResult<BnnApproxResult> {
    if map_weights.is_empty() {
        return Err(StatsError::invalid_argument(
            "fit_laplace: map_weights must not be empty",
        ));
    }
    // The original match on config.hessian_method was dead branching: every
    // arm (including the catch-all) executed this exact call. Collapse it;
    // method-specific estimators can reintroduce dispatch when they exist.
    let grad_matrix = fd_per_sample_gradients(map_weights, loss_fn, config.fd_step)?;
    let ggn = diagonal_ggn(&grad_matrix)?;
    let posterior_var = posterior_variance_from_ggn(&ggn, config.damping)?;
    Ok(BnnApproxResult {
        mean_weights: map_weights.to_vec(),
        uncertainty: posterior_var,
        method: format!("Laplace-{:?}", config.hessian_method),
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison shared by these tests.
    fn close(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-12
    }

    #[test]
    fn test_diagonal_laplace_squared_grads() {
        let per_sample = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let diag = diagonal_ggn(&per_sample).expect("ok");
        // 1^2 + 3^2 = 10, 2^2 + 4^2 = 20.
        assert!(close(diag[0], 10.0));
        assert!(close(diag[1], 20.0));
    }

    #[test]
    fn test_diagonal_ggn_single_sample() {
        let per_sample = vec![vec![3.0, -1.0, 2.0]];
        let diag = diagonal_ggn(&per_sample).expect("ok");
        for (got, want) in diag.iter().zip([9.0, 1.0, 4.0]) {
            assert!(close(*got, want));
        }
    }

    #[test]
    fn test_laplace_posterior_variance_positive() {
        let var = posterior_variance_from_ggn(&[1.0, 0.0, 5.0], 1.0).expect("ok");
        for &v in &var {
            assert!(v > 0.0, "variance should be positive, got {v}");
        }
        for (got, want) in var.iter().zip([0.5, 1.0, 1.0 / 6.0]) {
            assert!(close(*got, want));
        }
    }

    #[test]
    fn test_laplace_predict_variance_finite() {
        let var = predict_variance_linear(&[1.0, 2.0], &[0.5, 0.25]).expect("ok");
        assert!(var.is_finite(), "variance should be finite");
        assert!(close(var, 1.5), "expected 1.5, got {var}");
    }

    #[test]
    fn test_predict_mean_linear() {
        let features = vec![1.0, 2.0, 3.0];
        let coeffs = vec![1.0, 1.0, 1.0];
        let mean = predict_mean_linear(&features, &coeffs).expect("ok");
        assert!(close(mean, 6.0));
    }

    #[test]
    fn test_laplace_uncertainty_increases_far_from_data() {
        let posterior_var = vec![0.1];
        let var_near = predict_variance_linear(&[1.0], &posterior_var).expect("near");
        let var_far = predict_variance_linear(&[10.0], &posterior_var).expect("far");
        assert!(
            var_far > var_near,
            "Uncertainty should be higher far from data: near={var_near}, far={var_far}"
        );
    }

    #[test]
    fn test_fit_laplace_basic() {
        // Exact linear data y = x, so the loss is minimized at w = [1.0].
        let xs = vec![1.0f64, 2.0, 3.0];
        let ys = vec![1.0f64, 2.0, 3.0];
        let loss_fn = move |w: &[f64]| -> Vec<f64> {
            xs.iter()
                .zip(&ys)
                .map(|(&xi, &yi)| (yi - w[0] * xi).powi(2))
                .collect()
        };
        let result = fit_laplace(&[1.0], &loss_fn, &LaplaceConfig::default()).expect("fit");
        assert_eq!(result.mean_weights.len(), 1);
        assert_eq!(result.uncertainty.len(), 1);
        assert!(result.uncertainty[0] > 0.0, "variance must be positive");
        assert!(result.uncertainty[0].is_finite(), "variance must be finite");
    }

    #[test]
    fn test_fit_laplace_empty_weights_error() {
        let loss_fn = |_: &[f64]| vec![1.0];
        assert!(fit_laplace(&[], &loss_fn, &LaplaceConfig::default()).is_err());
    }

    #[test]
    fn test_diagonal_ggn_empty_error() {
        assert!(diagonal_ggn(&[]).is_err());
    }
}