use alloc::vec::Vec;
use crate::math;
/// Root-mean-square style normalizer with a stabilising constant and an output gain.
///
/// Every non-empty input `v` is multiplied by the single factor
/// `gamma / sqrt(mean(v_i^2) + eps)`.
#[derive(Debug, Clone)]
pub struct BCNorm {
    /// Constant added to the mean square before taking the square root.
    /// [`BCNorm::with_params`] rejects values that are not `> 0`.
    eps: f64,
    /// Gain applied to the normalized vector.
    /// [`BCNorm::with_params`] rejects values that are not `> 0`.
    gamma: f64,
}

impl BCNorm {
    /// Creates a normalizer with the defaults `eps = 1e-6` and `gamma = 1.0`.
    ///
    /// `_n_state` is accepted for call-site compatibility but is not used: the
    /// normalizer stores no per-dimension state.
    pub fn new(_n_state: usize) -> Self {
        Self {
            eps: 1e-6,
            gamma: 1.0,
        }
    }

    /// Creates a normalizer with explicit `eps` and `gamma`.
    ///
    /// # Panics
    ///
    /// Panics if `eps` or `gamma` is not strictly greater than zero. NaN fails
    /// the `> 0.0` comparison and is therefore rejected as well.
    pub fn with_params(eps: f64, gamma: f64) -> Self {
        assert!(eps > 0.0, "BCNorm eps must be > 0, got {}", eps);
        assert!(gamma > 0.0, "BCNorm gamma must be > 0, got {}", gamma);
        Self { eps, gamma }
    }

    /// Returns a freshly allocated, normalized copy of `v`.
    ///
    /// Element `i` of the result is `v[i] * gamma / sqrt(mean(v_j^2) + eps)`.
    /// An empty slice yields an empty vector.
    ///
    /// Finite inputs whose squares overflow `f64` are handled by rescaling
    /// (see `scale_for`), so the result does not collapse to zeros.
    pub fn normalize(&self, v: &[f64]) -> Vec<f64> {
        if v.is_empty() {
            return Vec::new();
        }
        let scale = self.scale_for(v);
        v.iter().map(|&vi| vi * scale).collect()
    }

    /// Writes the normalized form of `v` into `out` without allocating.
    ///
    /// Produces the same values as [`BCNorm::normalize`]. Does nothing when `v`
    /// is empty.
    ///
    /// # Panics
    ///
    /// Panics when `v.len() != out.len()` if debug assertions are enabled. With
    /// debug assertions disabled the check is compiled out: only the first
    /// `min(v.len(), out.len())` elements of `out` are written, while the scale
    /// is still derived from all of `v`.
    pub fn normalize_into(&self, v: &[f64], out: &mut [f64]) {
        debug_assert_eq!(
            v.len(),
            out.len(),
            "BCNorm: input and output slices must have equal length"
        );
        if v.is_empty() {
            return;
        }
        let scale = self.scale_for(v);
        for (o, &vi) in out.iter_mut().zip(v.iter()) {
            *o = vi * scale;
        }
    }

    /// Returns the configured stabilising constant.
    #[inline]
    pub fn eps(&self) -> f64 {
        self.eps
    }

    /// Returns the configured output gain.
    #[inline]
    pub fn gamma(&self) -> f64 {
        self.gamma
    }

    /// Computes the factor `gamma / sqrt(mean(v_i^2) + eps)` for a non-empty `v`.
    ///
    /// Shared by `normalize` and `normalize_into` so both always agree.
    fn scale_for(&self, v: &[f64]) -> f64 {
        let n = v.len() as f64;
        let mean_sq = v.iter().map(|&vi| vi * vi).sum::<f64>() / n;

        if mean_sq.is_infinite() {
            // The sum of squares overflowed. If every element is finite, the true
            // RMS may still be representable: factor the largest magnitude out of
            // the square root so that no intermediate value exceeds ~1.
            // Magnitude and maximum are spelled out with comparisons to stay
            // within operations this file already relies on outside of tests.
            let peak = v.iter().fold(0.0_f64, |best, &vi| {
                let mag = if vi < 0.0 { -vi } else { vi };
                if mag > best {
                    mag
                } else {
                    best
                }
            });
            if peak.is_finite() {
                let scaled_mean_sq = v
                    .iter()
                    .map(|&vi| {
                        let r = vi / peak;
                        r * r
                    })
                    .sum::<f64>()
                    / n;
                // eps / peak^2, divided in two steps so `peak * peak` is never formed.
                let scaled_eps = self.eps / peak / peak;
                return self.gamma / math::sqrt(scaled_mean_sq + scaled_eps) / peak;
            }
            // An infinite element is present: fall through and keep the plain
            // formula's result for such input.
        }

        self.gamma / math::sqrt(mean_sq + self.eps)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Scaling the input by 100 must leave the normalized output (nearly) unchanged.
    #[test]
    fn bcnorm_scale_invariance() {
        let bcn = BCNorm::new(4);
        let base = vec![1.0, -2.0, 3.0, -1.0];
        let mut amplified: Vec<f64> = Vec::with_capacity(base.len());
        for &x in &base {
            amplified.push(x * 100.0);
        }

        let from_base = bcn.normalize(&base);
        let from_amplified = bcn.normalize(&amplified);

        let pairs = from_base.iter().copied().zip(from_amplified.iter().copied());
        for (i, (a, b)) in pairs.enumerate() {
            let gap = (a - b).abs();
            assert!(
                gap < 1e-6,
                "BCNorm must be scale-invariant: norm(v)[{}]={} vs norm(100v)[{}]={}",
                i,
                a,
                i,
                b
            );
        }
    }

    /// The RMS of the output should equal the configured gain.
    #[test]
    fn bcnorm_output_has_unit_rms_times_gamma() {
        let bcn = BCNorm::new(4);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = bcn.normalize(&input);

        let sum_sq = output.iter().map(|&x| x * x).sum::<f64>();
        let count = output.len() as f64;
        let observed: f64 = math::sqrt(sum_sq / count);

        let deviation = (observed - bcn.gamma()).abs();
        assert!(
            deviation < 1e-6,
            "RMS of BCNorm output should be gamma={}, got {}",
            bcn.gamma(),
            observed
        );
    }

    /// With identical `eps`, a gain of 2 must double every output element.
    #[test]
    fn bcnorm_custom_gamma_scales_output() {
        let doubled = BCNorm::with_params(1e-6, 2.0);
        let unit = BCNorm::new(4);
        let input = vec![1.0, -1.0, 2.0, -2.0];

        let unit_out = unit.normalize(&input);
        let doubled_out = doubled.normalize(&input);

        for (i, (a, b)) in unit_out.iter().zip(doubled_out.iter()).enumerate() {
            let (a, b) = (*a, *b);
            let residual = b - 2.0 * a;
            assert!(
                residual.abs() < 1e-10,
                "gamma=2 should double the output at index {}: a={}, b={}",
                i,
                a,
                b
            );
        }
    }

    /// An all-zero input must map to finite zeros, not NaN.
    #[test]
    fn bcnorm_zero_vector_is_finite() {
        let bcn = BCNorm::new(4);
        let zeros = vec![0.0; 4];
        let result = bcn.normalize(&zeros);

        for value in result.iter().copied() {
            assert!(
                value.is_finite(),
                "BCNorm of zero vector must be finite: got {}",
                value
            );
            assert_eq!(value, 0.0, "BCNorm of zero vector must be zero, got {}", value);
        }
    }

    /// An empty slice must produce an empty vector.
    #[test]
    fn bcnorm_empty_vector() {
        let bcn = BCNorm::new(0);
        let nothing: [f64; 0] = [];
        let result = bcn.normalize(&nothing);
        assert!(
            result.is_empty(),
            "BCNorm of empty slice must return empty vec"
        );
    }

    /// The in-place variant must agree with the allocating variant.
    #[test]
    fn bcnorm_normalize_into_matches_normalize() {
        let bcn = BCNorm::new(4);
        let input = vec![1.0, -2.0, 3.0, -4.0];

        let expected = bcn.normalize(&input);
        let mut written = vec![0.0; 4];
        bcn.normalize_into(&input, &mut written);

        for (i, (a, b)) in expected.iter().zip(written.iter()).enumerate() {
            let (a, b) = (*a, *b);
            let gap = (a - b).abs();
            assert!(
                gap < 1e-12,
                "normalize_into must match normalize at index {}: {} vs {}",
                i,
                a,
                b
            );
        }
    }

    /// A single positive element normalizes to approximately the gain.
    #[test]
    fn bcnorm_single_element() {
        let bcn = BCNorm::new(1);
        let input = vec![5.0];
        let result = bcn.normalize(&input);

        let well_formed = result.len() == 1 && result[0].is_finite();
        assert!(well_formed, "single-element BCNorm must be finite");

        let deviation = (result[0] - bcn.gamma()).abs();
        assert!(
            deviation < 1e-4,
            "BCNorm of [5.0] should be ~gamma=1.0, got {}",
            result[0]
        );
    }
}