use super::kl_divergence::{DiagonalGaussian, KLDivergence};
use serde::{Deserialize, Serialize};
/// Hyper-parameters for [`InformationBottleneck`].
///
/// Serializable through serde; see the `Default` impl for the stock values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IBConfig {
/// Size of the bottleneck representation.
/// NOTE(review): no code visible in this file reads this field; presumably
/// it is consumed by callers that size the latent vectors — confirm.
pub bottleneck_dim: usize,
/// Weight applied to every KL term and KL gradient produced by
/// [`InformationBottleneck`].
pub beta: f32,
/// Variance floor: `InformationBottleneck::sample` never uses a
/// log-variance below `min_var.ln()`.
pub min_var: f32,
/// NOTE(review): not read by any code visible in this file; presumably
/// toggles reparameterized sampling in callers — confirm.
pub reparameterize: bool,
}
impl Default for IBConfig {
fn default() -> Self {
Self {
bottleneck_dim: 64,
beta: 1e-3,
min_var: 1e-4,
reparameterize: true,
}
}
}
/// Information-bottleneck helper: wraps an [`IBConfig`] and exposes
/// beta-weighted KL losses, KL gradients, latent sampling and attention-weight
/// compression through its methods.
#[derive(Debug, Clone)]
pub struct InformationBottleneck {
// Hyper-parameters; `beta` can be changed later through `set_beta`.
config: IBConfig,
}
impl InformationBottleneck {
    /// Creates a bottleneck that uses `config` exactly as given (no validation).
    pub fn new(config: IBConfig) -> Self {
        Self { config }
    }

    /// Returns `beta` times the KL value reported by
    /// `KLDivergence::gaussian_to_unit_arrays(mean, log_var)`.
    pub fn compute_kl_loss(&self, mean: &[f32], log_var: &[f32]) -> f32 {
        self.config.beta * KLDivergence::gaussian_to_unit_arrays(mean, log_var)
    }

    /// Returns `beta` times the KL value reported by
    /// `KLDivergence::gaussian_to_unit(gaussian)`.
    pub fn compute_kl_loss_gaussian(&self, gaussian: &DiagonalGaussian) -> f32 {
        self.config.beta * KLDivergence::gaussian_to_unit(gaussian)
    }

    /// Draws `z = mean + std * epsilon` element-wise, where
    /// `std = exp(0.5 * log_var)`.
    ///
    /// Each log-variance is first floored at `ln(min_var)` and then clamped to
    /// `[-20, 20]` so the exponential stays finite. The caller supplies the
    /// noise in `epsilon`. The output has the length of the shortest input
    /// slice; extra elements of longer slices are ignored.
    pub fn sample(&self, mean: &[f32], log_var: &[f32], epsilon: &[f32]) -> Vec<f32> {
        // Loop-invariant, so computed once. If `min_var` is not positive this
        // is NaN or -inf, and `f32::max` then leaves the log-variance unchanged.
        let log_var_floor = self.config.min_var.ln();
        mean.iter()
            .zip(log_var)
            .zip(epsilon)
            .map(|((&mu, &lv), &eps)| {
                let lv = lv.max(log_var_floor).clamp(-20.0, 20.0);
                mu + (0.5 * lv).exp() * eps
            })
            .collect()
    }

    /// Returns `(d_mean, d_log_var)`, the element-wise gradients of
    /// `beta * 0.5 * sum(mean^2 + exp(log_var) - log_var - 1)`.
    ///
    /// `d_mean[i] = beta * mean[i]` and
    /// `d_log_var[i] = beta * 0.5 * (exp(log_var[i]) - 1)`, with `log_var[i]`
    /// clamped to `[-20, 20]` before exponentiation. Both vectors have the
    /// length of the shorter input.
    pub fn kl_gradients(&self, mean: &[f32], log_var: &[f32]) -> (Vec<f32>, Vec<f32>) {
        let beta = self.config.beta;
        mean.iter()
            .zip(log_var)
            .map(|(&mu, &lv)| {
                let lv = lv.clamp(-20.0, 20.0);
                (beta * mu, beta * 0.5 * (lv.exp() - 1.0))
            })
            .unzip()
    }

    /// Sharpens or flattens a weight vector with a temperature and reports a
    /// beta-weighted KL penalty for the input distribution.
    ///
    /// Returns `(compressed, penalty)`:
    /// * `compressed[i]` is proportional to `p[i]^(1 / T)`, renormalized to
    ///   sum to 1, where `p` is the normalized input and
    ///   `T = max(temperature, 0.1)`.
    /// * `penalty = beta * max(ln(n) - H(p), 0)`, i.e. the divergence of the
    ///   normalized input (not of the compressed output) from the uniform
    ///   distribution over `n = weights.len()` entries.
    ///
    /// Negative or NaN weights are treated as zero mass; passing them through
    /// `powf` with a fractional exponent would yield NaN and poison the whole
    /// output. When the weights have no positive finite total they cannot be
    /// normalized and are left unscaled. Empty input gives `(vec![], 0.0)`.
    pub fn compress_attention_weights(&self, weights: &[f32], temperature: f32) -> (Vec<f32>, f32) {
        // `f32::max` returns the non-NaN operand, so NaN also maps to 0.
        let mut compressed: Vec<f32> = weights.iter().map(|&w| w.max(0.0)).collect();

        // Entropy is only meaningful for a distribution, so normalize first;
        // otherwise unnormalized input would misreport its distance from uniform.
        Self::normalize_in_place(&mut compressed);
        let uniform_entropy = (compressed.len() as f32).ln();
        let kl = (uniform_entropy - self.compute_entropy(&compressed)).max(0.0);

        let exponent = 1.0 / temperature.max(0.1);
        for w in &mut compressed {
            *w = w.powf(exponent);
        }
        Self::normalize_in_place(&mut compressed);

        (compressed, self.config.beta * kl)
    }

    /// Divides every element by the slice total when that total is positive
    /// and finite; otherwise leaves the slice untouched.
    fn normalize_in_place(values: &mut [f32]) {
        let total: f32 = values.iter().sum();
        if total > 0.0 && total.is_finite() {
            for v in values.iter_mut() {
                *v /= total;
            }
        }
    }

    /// Shannon entropy (in nats) of `weights`, skipping entries at or below
    /// `1e-10`, and never returning a negative value.
    fn compute_entropy(&self, weights: &[f32]) -> f32 {
        const EPS: f32 = 1e-10;
        weights
            .iter()
            .filter(|&&w| w > EPS)
            .fold(0.0f32, |acc, &w| acc - w * w.ln())
            .max(0.0)
    }

    /// Sets the KL weight, clamping negative (or NaN) values to `0.0`.
    pub fn set_beta(&mut self, beta: f32) {
        self.config.beta = beta.max(0.0);
    }

    /// Current KL weight.
    pub fn beta(&self) -> f32 {
        self.config.beta
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by every floating-point comparison below.
    const TOL: f32 = 1e-5;

    /// True when `actual` is within `TOL` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOL
    }

    /// Bottleneck built from the default configuration.
    fn default_ib() -> InformationBottleneck {
        InformationBottleneck::new(IBConfig::default())
    }

    /// Zero mean and zero log-variance are expected to give a zero KL loss.
    #[test]
    fn test_ib_kl_loss() {
        let zeros = [0.0f32; 16];
        let loss = default_ib().compute_kl_loss(&zeros, &zeros);
        assert!(close(loss, 0.0));
    }

    /// With zero noise the sample must equal the mean.
    #[test]
    fn test_ib_sample() {
        let mean = [1.0f32, 2.0];
        let zeros = [0.0f32, 0.0];
        let z = default_ib().sample(&mean, &zeros, &zeros);
        assert!(close(z[0], 1.0));
        assert!(close(z[1], 2.0));
    }

    /// With `beta = 1` the mean gradient is the mean itself, and the
    /// log-variance gradient vanishes at zero log-variance.
    #[test]
    fn test_kl_gradients() {
        let config = IBConfig {
            beta: 1.0,
            ..Default::default()
        };
        let ib = InformationBottleneck::new(config);
        let (d_mean, d_log_var) = ib.kl_gradients(&[1.0, 0.0], &[0.0, 0.0]);
        assert!(close(d_mean[0], 1.0));
        assert!(close(d_mean[1], 0.0));
        assert!(close(d_log_var[0], 0.0));
    }

    /// Compression keeps the length, returns a non-negative penalty and
    /// yields weights that sum to one.
    #[test]
    fn test_compress_weights() {
        let weights = [0.7f32, 0.2, 0.1];
        let (compressed, kl) = default_ib().compress_attention_weights(&weights, 1.0);
        assert_eq!(compressed.len(), 3);
        assert!(kl >= 0.0);
        let total = compressed.iter().sum::<f32>();
        assert!(close(total, 1.0));
    }
}