use crate::error::{VisionError, VisionResult};
/// Per-channel batch-normalisation parameters consumed by the folding routines in this file.
///
/// The folding functions read these as `scale = gamma[i] / sqrt(var[i] + eps)` and
/// `shift = beta[i] - mean[i] * scale`, so all four vectors are indexed by output channel
/// and are expected to have one entry per channel.
#[derive(Clone, Debug)]
pub struct BnParams {
    /// Multiplicative factor applied after normalisation, one entry per channel.
    pub gamma: Vec<f32>,
    /// Additive offset applied after scaling, one entry per channel.
    pub beta: Vec<f32>,
    /// Value subtracted from the layer output before scaling, one entry per channel.
    pub mean: Vec<f32>,
    /// Variance per channel; `eps` is added to it before the square root is taken.
    pub var: Vec<f32>,
    /// Constant added to every variance before the square root.
    pub eps: f32,
}
/// Folds a batch-norm transform into the linear layer that precedes it.
///
/// For every output row `i` this computes `scale = gamma[i] / sqrt(var[i] + eps)`, multiplies
/// the `in_features` weights of that row by `scale`, and rewrites the bias as
/// `(bias[i] - mean[i]) * scale + beta[i]`. The returned pair `(w_new, b_new)` has the same
/// lengths as `weight` and `bias`.
///
/// `weight` is read row-major: row `i` is `weight[i * in_features..(i + 1) * in_features]`.
///
/// # Errors
///
/// * Whatever `validate_bn` reports for `bn` against `out_features` (length mismatch of a BN
///   vector, or a variance below `-eps`).
/// * [`VisionError::DimensionMismatch`] if `weight.len() != out_features * in_features`
///   (including when that product does not fit in `usize`) or `bias.len() != out_features`.
/// * [`VisionError::NonFinite`] if `sqrt(var[i] + eps)` is zero or not finite for some row.
pub fn fold_bn_into_linear(
    weight: &[f32],
    bias: &[f32],
    bn: &BnParams,
    out_features: usize,
    in_features: usize,
) -> VisionResult<(Vec<f32>, Vec<f32>)> {
    validate_bn(bn, out_features)?;

    // Saturate rather than overflow: a plain `*` panics in debug builds and wraps in release
    // builds, where a wrapped product could slip past the length check and cause an
    // out-of-bounds index below. A slice of `f32` can never hold `usize::MAX` elements, so a
    // saturated product always ends up in the mismatch branch.
    let expected_weights = out_features.saturating_mul(in_features);
    if weight.len() != expected_weights {
        return Err(VisionError::DimensionMismatch {
            expected: expected_weights,
            got: weight.len(),
        });
    }
    if bias.len() != out_features {
        return Err(VisionError::DimensionMismatch {
            expected: out_features,
            got: bias.len(),
        });
    }

    let mut w_new = Vec::with_capacity(weight.len());
    let mut b_new = Vec::with_capacity(out_features);
    for i in 0..out_features {
        let sigma = (bn.var[i] + bn.eps).sqrt();
        if !sigma.is_finite() || sigma == 0.0 {
            return Err(VisionError::NonFinite("BN sigma is zero or non-finite"));
        }
        let scale = bn.gamma[i] / sigma;

        // `(i + 1) * in_features <= weight.len()` was established by the length check above.
        // Index slicing (not `chunks_exact`) keeps `in_features == 0` working.
        let row = &weight[i * in_features..(i + 1) * in_features];
        w_new.extend(row.iter().map(|&w| w * scale));
        b_new.push((bias[i] - bn.mean[i]) * scale + bn.beta[i]);
    }
    Ok((w_new, b_new))
}
/// Folds a batch-norm transform into the convolution that precedes it.
///
/// `weight` is treated as `out_ch` contiguous, equally sized slices (for example a flattened
/// `out_ch * in_ch * kh * kw` kernel with the output channel outermost). Every element of
/// slice `i` is multiplied by `scale = gamma[i] / sqrt(var[i] + eps)` and the bias becomes
/// `(bias[i] - mean[i]) * scale + beta[i]`. The returned pair `(w_new, b_new)` has the same
/// lengths as `weight` and `bias`.
///
/// With `out_ch == 0` the only valid input is an empty `weight`, for which two empty vectors
/// are returned.
///
/// # Errors
///
/// * Whatever `validate_bn` reports for `bn` against `out_ch` (length mismatch of a BN vector,
///   or a variance below `-eps`).
/// * [`VisionError::DimensionMismatch`] if `bias.len() != out_ch`, if `out_ch == 0` while
///   `weight` is non-empty, or if `weight.len()` is not a multiple of `out_ch` (in that last
///   case `expected` is `0` and `got` is the remainder).
/// * [`VisionError::NonFinite`] if `sqrt(var[i] + eps)` is zero or not finite for some channel.
pub fn fold_bn_into_conv(
    weight: &[f32],
    bias: &[f32],
    bn: &BnParams,
    out_ch: usize,
) -> VisionResult<(Vec<f32>, Vec<f32>)> {
    validate_bn(bn, out_ch)?;
    if bias.len() != out_ch {
        return Err(VisionError::DimensionMismatch {
            expected: out_ch,
            got: bias.len(),
        });
    }

    // `weight.len() % 0` would panic, so the zero-channel case is decided up front.
    if out_ch == 0 {
        return if weight.is_empty() {
            Ok((Vec::new(), Vec::new()))
        } else {
            Err(VisionError::DimensionMismatch {
                expected: 0,
                got: weight.len(),
            })
        };
    }

    let remainder = weight.len() % out_ch;
    if remainder != 0 {
        return Err(VisionError::DimensionMismatch {
            expected: 0,
            got: remainder,
        });
    }
    let slice_len = weight.len() / out_ch;

    let mut w_new = Vec::with_capacity(weight.len());
    let mut b_new = Vec::with_capacity(out_ch);
    for i in 0..out_ch {
        let sigma = (bn.var[i] + bn.eps).sqrt();
        if !sigma.is_finite() || sigma == 0.0 {
            return Err(VisionError::NonFinite("BN sigma is zero or non-finite"));
        }
        let scale = bn.gamma[i] / sigma;

        // Index slicing (not `chunks_exact`) keeps an empty `weight` (slice_len == 0) working.
        let kernel = &weight[i * slice_len..(i + 1) * slice_len];
        w_new.extend(kernel.iter().map(|&w| w * scale));
        b_new.push((bias[i] - bn.mean[i]) * scale + bn.beta[i]);
    }
    Ok((w_new, b_new))
}
/// Measures how closely a folded linear layer reproduces "linear followed by batch norm".
///
/// `x` is read as `n_samples` consecutive rows of `in_features` values. For each row and each
/// output feature the function evaluates the unfolded path
/// `(bias[i] + w_i . x - mean[i]) * gamma[i] / sqrt(var[i] + eps) + beta[i]` and the folded
/// path `b_new[i] + w_new_i . x`, where `(w_new, b_new)` comes from [`fold_bn_into_linear`].
///
/// Returns the largest absolute difference between the two paths, or `0.0` when there is
/// nothing to compare. A NaN difference never raises the maximum because `NaN > max` is false.
///
/// # Errors
///
/// * [`VisionError::DimensionMismatch`] if `x.len() != n_samples * in_features` (including when
///   that product does not fit in `usize`).
/// * Any error returned by [`fold_bn_into_linear`] for the given weights and BN parameters.
pub fn verify_bn_fold(
    x: &[f32],
    weight: &[f32],
    bias: &[f32],
    bn: &BnParams,
    out_features: usize,
    in_features: usize,
    n_samples: usize,
) -> VisionResult<f32> {
    // Saturate rather than overflow: a plain `*` panics in debug builds and wraps in release
    // builds, where a wrapped product could pass the check and lead to out-of-bounds slicing.
    // A slice of `f32` can never hold `usize::MAX` elements, so saturation forces a mismatch.
    let expected_inputs = n_samples.saturating_mul(in_features);
    if x.len() != expected_inputs {
        return Err(VisionError::DimensionMismatch {
            expected: expected_inputs,
            got: x.len(),
        });
    }

    let (w_new, b_new) = fold_bn_into_linear(weight, bias, bn, out_features, in_features)?;

    let mut max_err = 0.0_f32;
    for s in 0..n_samples {
        let x_row = &x[s * in_features..(s + 1) * in_features];
        // Sequential dot product of one weight row with the current sample.
        let dot_with_sample = |w_row: &[f32]| -> f32 {
            w_row.iter().zip(x_row).map(|(&w, &xv)| w * xv).sum()
        };

        for i in 0..out_features {
            let (lo, hi) = (i * in_features, (i + 1) * in_features);

            // Reference path: linear layer, then explicit batch norm.
            let y_orig_i = bias[i] + dot_with_sample(&weight[lo..hi]);
            let sigma = (bn.var[i] + bn.eps).sqrt();
            let bn_out_i = (y_orig_i - bn.mean[i]) * bn.gamma[i] / sigma + bn.beta[i];

            // Folded path: a single linear layer with the rewritten parameters.
            let y_fold_i = b_new[i] + dot_with_sample(&w_new[lo..hi]);

            let err = (bn_out_i - y_fold_i).abs();
            if err > max_err {
                max_err = err;
            }
        }
    }
    Ok(max_err)
}
/// Checks that `bn` is usable for `n_channels` channels.
///
/// The variance check runs first: any entry of `bn.var` strictly below `-bn.eps` yields
/// `VisionError::NonFinite`. After that the lengths of `gamma`, `beta`, `mean` and `var` are
/// compared against `n_channels` in that order, and the first one that differs is reported as
/// `VisionError::DimensionMismatch`.
fn validate_bn(bn: &BnParams, n_channels: usize) -> VisionResult<()> {
    let lowest_allowed = -bn.eps;
    for &variance in &bn.var {
        if variance < lowest_allowed {
            return Err(VisionError::NonFinite("BN variance is negative"));
        }
    }

    let lengths = [
        bn.gamma.len(),
        bn.beta.len(),
        bn.mean.len(),
        bn.var.len(),
    ];
    match lengths.iter().find(|&&len| len != n_channels) {
        Some(&got) => Err(VisionError::DimensionMismatch {
            expected: n_channels,
            got,
        }),
        None => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// BN parameters for `c` channels with gamma 1, beta 0, mean 0, variance 1 and eps 1e-5.
    fn default_bn(c: usize) -> BnParams {
        let ones = vec![1.0_f32; c];
        let zeros = vec![0.0_f32; c];
        BnParams {
            gamma: ones.clone(),
            beta: zeros.clone(),
            mean: zeros,
            var: ones,
            eps: 1e-5,
        }
    }

    /// `n` pseudo-random values, each computed as `next_f32() * 2.0 - 1.0` from a generator
    /// seeded with `seed`.
    fn rand_vec(n: usize, seed: u64) -> Vec<f32> {
        let mut rng = LcgRng::new(seed);
        let mut values = Vec::with_capacity(n);
        for _ in 0..n {
            values.push(rng.next_f32() * 2.0 - 1.0);
        }
        values
    }

    /// Folded weights and bias keep the lengths of their inputs.
    #[test]
    fn fold_output_shape() {
        let (out_f, in_f) = (4, 8);
        let weight = vec![1.0_f32; out_f * in_f];
        let bias = vec![0.0_f32; out_f];
        let params = default_bn(out_f);

        let (folded_w, folded_b) = fold_bn_into_linear(&weight, &bias, &params, out_f, in_f)
            .expect("fold_bn_into_linear should succeed");

        assert_eq!(folded_w.len(), out_f * in_f);
        assert_eq!(folded_b.len(), out_f);
    }

    /// The folded layer matches linear + BN on random data to within 1e-5.
    #[test]
    fn fold_preserves_computation() {
        let (out_f, in_f, n) = (4, 3, 5);
        let weight = rand_vec(out_f * in_f, 1);
        let bias = rand_vec(out_f, 2);
        let params = BnParams {
            gamma: rand_vec(out_f, 3),
            beta: rand_vec(out_f, 4),
            mean: rand_vec(out_f, 5),
            var: vec![0.5_f32, 1.0, 2.0, 0.8],
            eps: 1e-5,
        };
        let inputs = rand_vec(n * in_f, 6);

        let err = verify_bn_fold(&inputs, &weight, &bias, &params, out_f, in_f, n)
            .expect("verify_bn_fold should succeed");

        assert!(err < 1e-5, "max fold error = {err}");
    }

    /// With zero mean and zero beta, weights and bias are simply multiplied by gamma / sigma.
    #[test]
    fn zero_mean_fold() {
        let (out_f, in_f) = (3, 2);
        let weight = vec![1.0_f32; out_f * in_f];
        let bias = vec![0.5_f32; out_f];
        let params = BnParams {
            gamma: vec![2.0_f32; out_f],
            beta: vec![0.0_f32; out_f],
            mean: vec![0.0_f32; out_f],
            var: vec![1.0_f32; out_f],
            eps: 1e-5,
        };

        let (folded_w, folded_b) = fold_bn_into_linear(&weight, &bias, &params, out_f, in_f)
            .expect("fold_bn_into_linear should succeed");

        let sigma = (1.0_f32 + 1e-5).sqrt();
        let expected_w = 2.0 / sigma;
        let expected_b = (0.5 - 0.0) * 2.0 / sigma;

        for v in folded_w.iter().copied() {
            assert!(
                (v - expected_w).abs() < 1e-4,
                "w_new={v}, expected={expected_w}"
            );
        }
        for v in folded_b.iter().copied() {
            assert!(
                (v - expected_b).abs() < 1e-4,
                "b_new={v}, expected={expected_b}"
            );
        }
    }

    /// Default BN parameters leave an identity weight matrix and zero bias essentially unchanged.
    #[test]
    fn unit_var_fold() {
        let (out_f, in_f) = (2, 2);
        let weight = vec![1.0_f32, 0.0, 0.0, 1.0];
        let bias = vec![0.0_f32; out_f];
        let params = default_bn(out_f);

        let (folded_w, folded_b) = fold_bn_into_linear(&weight, &bias, &params, out_f, in_f)
            .expect("fold_bn_into_linear should succeed");

        for (orig, folded) in weight.iter().zip(folded_w.iter()) {
            assert!((orig - folded).abs() < 1e-5);
        }
        for v in folded_b.iter().copied() {
            assert!(v.abs() < 1e-5);
        }
    }

    /// With gamma 1 and beta 0 the fold reduces to dividing by sigma after subtracting the mean.
    #[test]
    fn gamma_1_beta_0_simplifies() {
        let (out_f, in_f) = (2, 2);
        let weight = rand_vec(out_f * in_f, 10);
        let bias = rand_vec(out_f, 11);
        let params = BnParams {
            gamma: vec![1.0_f32; out_f],
            beta: vec![0.0_f32; out_f],
            mean: vec![0.5_f32; out_f],
            var: vec![4.0_f32; out_f],
            eps: 1e-5,
        };

        let (folded_w, folded_b) = fold_bn_into_linear(&weight, &bias, &params, out_f, in_f)
            .expect("fold_bn_into_linear should succeed");

        let sigma = (4.0_f32 + 1e-5).sqrt();
        for (orig, folded) in weight.iter().zip(folded_w.iter()) {
            let expect = orig / sigma;
            assert!((folded - expect).abs() < 1e-5, "{folded} vs {expect}");
        }
        for (i, bv) in folded_b.iter().copied().enumerate() {
            let expect = (bias[i] - 0.5) / sigma;
            assert!((bv - expect).abs() < 1e-5, "b_new[{i}]={bv} vs {expect}");
        }
    }

    /// A clearly negative variance is rejected.
    #[test]
    fn var_negative_error() {
        let (out_f, in_f) = (2, 2);
        let weight = vec![1.0_f32; out_f * in_f];
        let bias = vec![0.0_f32; out_f];
        let params = BnParams {
            gamma: vec![1.0_f32; out_f],
            beta: vec![0.0_f32; out_f],
            mean: vec![0.0_f32; out_f],
            var: vec![-1.0_f32, 1.0],
            eps: 1e-5,
        };

        let outcome = fold_bn_into_linear(&weight, &bias, &params, out_f, in_f);

        assert!(outcome.is_err());
    }

    /// Folding into a conv kernel keeps the weight and bias lengths.
    #[test]
    fn bn_fold_conv_shape() {
        let (out_ch, in_ch, kh, kw) = (4, 3, 3, 3);
        let weight = rand_vec(out_ch * in_ch * kh * kw, 20);
        let bias = vec![0.0_f32; out_ch];
        let params = default_bn(out_ch);

        let (folded_w, folded_b) = fold_bn_into_conv(&weight, &bias, &params, out_ch)
            .expect("fold_bn_into_conv should succeed");

        assert_eq!(folded_w.len(), out_ch * in_ch * kh * kw);
        assert_eq!(folded_b.len(), out_ch);
    }

    /// The verification error stays below 1e-4 across a spread of variances.
    #[test]
    fn verify_error_small() {
        let (out_f, in_f, n) = (6, 4, 8);
        let weight = rand_vec(out_f * in_f, 30);
        let bias = rand_vec(out_f, 31);
        let params = BnParams {
            gamma: rand_vec(out_f, 32),
            beta: rand_vec(out_f, 33),
            mean: rand_vec(out_f, 34),
            var: vec![0.1_f32, 0.5, 1.0, 2.0, 0.3, 1.5],
            eps: 1e-5,
        };
        let inputs = rand_vec(n * in_f, 35);

        let err = verify_bn_fold(&inputs, &weight, &bias, &params, out_f, in_f, n)
            .expect("verify_bn_fold should succeed");

        assert!(err < 1e-4, "max error = {err}");
    }

    /// Verification works for several batch sizes.
    #[test]
    fn batch_size_varies() {
        let (out_f, in_f) = (3, 2);
        let weight = rand_vec(out_f * in_f, 40);
        let bias = rand_vec(out_f, 41);
        let params = default_bn(out_f);

        for n in [1_usize, 4, 16] {
            let inputs = rand_vec(n * in_f, n as u64);
            let err = verify_bn_fold(&inputs, &weight, &bias, &params, out_f, in_f, n)
                .expect("verify_bn_fold should succeed");
            assert!(err < 1e-4, "n={n} error={err}");
        }
    }

    /// Each output channel's weights are scaled by that channel's own gamma.
    #[test]
    fn bn_fold_conv_gamma_scales_weights() {
        let out_ch = 2;
        let weight = vec![1.0_f32, 1.0];
        let bias = vec![0.0_f32; out_ch];
        let params = BnParams {
            gamma: vec![2.0_f32, 3.0],
            beta: vec![0.0_f32; out_ch],
            mean: vec![0.0_f32; out_ch],
            var: vec![1.0_f32; out_ch],
            eps: 1e-5,
        };

        let (folded_w, _) = fold_bn_into_conv(&weight, &bias, &params, out_ch)
            .expect("fold_bn_into_conv should succeed");

        let sigma = (1.0_f32 + 1e-5).sqrt();
        assert!((folded_w[0] - 2.0 / sigma).abs() < 1e-4);
        assert!((folded_w[1] - 3.0 / sigma).abs() < 1e-4);
    }
}