use crate::distance::{cosine_similarity_fast, norm_fast};
/// Snapshot of three scalar quantities summarising a set of embeddings
/// weighted by an attention distribution.
///
/// `Default` yields all three quantities as `0.0`, which is also what the
/// constructors return for empty input.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct ConservationMetrics {
    /// Attention-weighted sum of embedding norms (via `norm_fast`).
    pub magnitude: f32,
    /// Half the attention-weighted sum of pairwise cosine similarities over
    /// all ordered pairs, self-pairs included.
    pub energy: f32,
    /// Shannon entropy (natural log) of the attention weights.
    pub information: f32,
}
impl ConservationMetrics {
    /// Computes magnitude, energy and information for `embeddings` weighted by `attention`.
    ///
    /// * `magnitude` = Σᵢ aᵢ·‖eᵢ‖
    /// * `energy` = ½·Σᵢ Σⱼ aᵢ·aⱼ·cos(eᵢ, eⱼ), including i = j
    /// * `information` = −Σ a·ln(a) over weights strictly greater than `1e-10`
    ///
    /// Returns all-zero metrics when `embeddings` is empty.
    ///
    /// # Panics
    /// Panics if `embeddings` and `attention` have different lengths.
    pub fn compute(embeddings: &[&[f32]], attention: &[f32]) -> Self {
        assert_eq!(embeddings.len(), attention.len(), "Embeddings and attention must have same length");
        if embeddings.is_empty() {
            return Self::default();
        }

        let magnitude: f32 = embeddings
            .iter()
            .zip(attention)
            .map(|(row, &weight)| weight * norm_fast(row))
            .sum();

        // Quadratic form over every ordered pair (row-major), then halved.
        let pair_sum = embeddings
            .iter()
            .zip(attention)
            .fold(0.0_f32, |outer, (ei, &ai)| {
                embeddings
                    .iter()
                    .zip(attention)
                    .fold(outer, |acc, (ej, &aj)| acc + ai * aj * cosine_similarity_fast(ei, ej))
            });
        let energy = 0.5 * pair_sum;

        // Tiny weights are skipped so ln(0) never produces NaN/-inf.
        let plogp: f32 = attention
            .iter()
            .copied()
            .filter(|&weight| weight > 1e-10)
            .map(|weight| weight * weight.ln())
            .sum();
        let information = -plogp;

        Self { magnitude, energy, information }
    }
    /// Convenience wrapper over [`ConservationMetrics::compute`] for owned vectors.
    ///
    /// # Panics
    /// Panics under the same conditions as `compute`.
    pub fn from_vecs(embeddings: &[Vec<f32>], attention: &[f32]) -> Self {
        let slices: Vec<&[f32]> = embeddings.iter().map(Vec::as_slice).collect();
        Self::compute(&slices, attention)
    }
    /// True when magnitude and energy each differ from `other` by strictly less
    /// than `tolerance`. Information is not compared.
    #[inline]
    pub fn is_conserved(&self, other: &Self, tolerance: f32) -> bool {
        let within = |lhs: f32, rhs: f32| (lhs - rhs).abs() < tolerance;
        within(self.magnitude, other.magnitude) && within(self.energy, other.energy)
    }
    /// Like [`ConservationMetrics::is_conserved`], but also requires the
    /// information difference to be strictly below `tolerance`.
    #[inline]
    pub fn is_fully_conserved(&self, other: &Self, tolerance: f32) -> bool {
        if !self.is_conserved(other, tolerance) {
            return false;
        }
        (self.information - other.information).abs() < tolerance
    }
    /// Absolute per-quantity differences between `self` and `other`.
    pub fn violation(&self, other: &Self) -> ConservationViolation {
        let gap = |lhs: f32, rhs: f32| (lhs - rhs).abs();
        ConservationViolation {
            magnitude_delta: gap(self.magnitude, other.magnitude),
            energy_delta: gap(self.energy, other.energy),
            information_delta: gap(self.information, other.information),
        }
    }
    /// Computes metrics with every embedding weighted `1 / n`.
    ///
    /// Returns all-zero metrics when `embeddings` is empty.
    pub fn uniform(embeddings: &[&[f32]]) -> Self {
        match embeddings.len() {
            0 => Self::default(),
            count => {
                let share = 1.0 / count as f32;
                Self::compute(embeddings, &vec![share; count])
            }
        }
    }
    /// Largest possible entropy of a distribution over `n` outcomes: `ln(n)`,
    /// or `0.0` when `n <= 1`.
    #[inline]
    pub fn max_entropy(n: usize) -> f32 {
        match n {
            0 | 1 => 0.0,
            _ => (n as f32).ln(),
        }
    }
    /// `information / ln(n)`. Returns `0.0` when `n <= 1` (the maximum entropy is zero there).
    #[inline]
    pub fn normalized_entropy(&self, n: usize) -> f32 {
        let ceiling = Self::max_entropy(n);
        if ceiling <= 0.0 {
            return 0.0;
        }
        self.information / ceiling
    }
}
/// Absolute per-quantity differences between two `ConservationMetrics` snapshots.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ConservationViolation {
    /// Absolute magnitude difference.
    pub magnitude_delta: f32,
    /// Absolute energy difference.
    pub energy_delta: f32,
    /// Absolute information (entropy) difference.
    pub information_delta: f32,
}
impl ConservationViolation {
    /// Sum of the three deltas.
    #[inline]
    pub fn total(&self) -> f32 {
        let Self { magnitude_delta, energy_delta, information_delta } = *self;
        magnitude_delta + energy_delta + information_delta
    }
    /// Largest of the three deltas (NaN handling follows `f32::max`).
    #[inline]
    pub fn max(&self) -> f32 {
        [self.energy_delta, self.information_delta]
            .iter()
            .fold(self.magnitude_delta, |largest, &delta| largest.max(delta))
    }
    /// True when every delta is strictly below the single `tolerance`.
    #[inline]
    pub fn is_acceptable(&self, tolerance: f32) -> bool {
        tolerance > self.max()
    }
}
/// Per-quantity tolerances used to judge a `ConservationViolation`.
#[derive(Debug, Clone)]
pub struct ConservationConfig {
    /// Upper bound (exclusive) on the magnitude delta.
    pub magnitude_tolerance: f32,
    /// Upper bound (exclusive) on the energy delta.
    pub energy_tolerance: f32,
    /// Upper bound (exclusive) on the information delta.
    pub information_tolerance: f32,
    /// Strict-mode flag. NOTE(review): not read by any method in this type;
    /// presumably interpreted by callers — confirm.
    pub strict: bool,
}
impl Default for ConservationConfig {
    /// Loose preset: 0.01 for magnitude and energy, 0.1 for information, non-strict.
    fn default() -> Self {
        Self {
            strict: false,
            information_tolerance: 0.1,
            energy_tolerance: 0.01,
            magnitude_tolerance: 0.01,
        }
    }
}
impl ConservationConfig {
    /// Tight preset: 0.001 for magnitude and energy, 0.01 for information, `strict = true`.
    pub fn strict() -> Self {
        Self {
            strict: true,
            information_tolerance: 0.01,
            energy_tolerance: 0.001,
            magnitude_tolerance: 0.001,
        }
    }
    /// True when each delta in `violation` is strictly below its matching tolerance.
    pub fn is_acceptable(&self, violation: &ConservationViolation) -> bool {
        let limits = [
            (violation.magnitude_delta, self.magnitude_tolerance),
            (violation.energy_delta, self.energy_tolerance),
            (violation.information_delta, self.information_tolerance),
        ];
        limits.iter().all(|&(delta, bound)| delta < bound)
    }
}
/// Records `ConservationMetrics` snapshots in order and compares the most
/// recent one against the first under a `ConservationConfig`.
#[derive(Debug, Clone)]
pub struct ConservationTracker {
    /// Snapshots in recording order; the first entry is the baseline.
    history: Vec<ConservationMetrics>,
    /// Tolerances applied by `is_conserved_from_initial`.
    config: ConservationConfig,
}
impl ConservationTracker {
    /// Creates a tracker with no history that judges drift using `config`.
    pub fn new(config: ConservationConfig) -> Self {
        Self { config, history: Vec::new() }
    }
    /// Appends a snapshot to the history.
    pub fn record(&mut self, metrics: ConservationMetrics) {
        self.history.push(metrics);
    }
    /// Most recently recorded snapshot, if any.
    pub fn current(&self) -> Option<&ConservationMetrics> {
        self.history.last()
    }
    /// First recorded snapshot (the baseline), if any.
    pub fn initial(&self) -> Option<&ConservationMetrics> {
        self.history.first()
    }
    /// Whether the drift from the first to the latest snapshot is within the
    /// configured tolerances; `None` when nothing has been recorded.
    pub fn is_conserved_from_initial(&self) -> Option<bool> {
        self.total_drift()
            .map(|drift| self.config.is_acceptable(&drift))
    }
    /// Violation between the first and latest snapshot; `None` when empty.
    pub fn total_drift(&self) -> Option<ConservationViolation> {
        Some(self.initial()?.violation(self.current()?))
    }
    /// All recorded snapshots in order.
    pub fn history(&self) -> &[ConservationMetrics] {
        self.history.as_slice()
    }
    /// Discards every recorded snapshot; the configuration is kept.
    pub fn clear(&mut self) {
        self.history.clear();
    }
}
/// Attention-weighted sum of `embeddings`: Σᵢ aᵢ·eᵢ.
///
/// The result has the length of the first embedding. Components of longer
/// embeddings beyond that length are ignored. Only as many (embedding, weight)
/// pairs as the shorter input provides are used. The sum is not divided by the
/// total weight, so it is the weighted mean only when `attention` sums to 1.
/// Returns an empty vector when `embeddings` is empty.
pub fn weighted_centroid(embeddings: &[&[f32]], attention: &[f32]) -> Vec<f32> {
    let dim = match embeddings.first() {
        Some(first) => first.len(),
        None => return Vec::new(),
    };
    embeddings
        .iter()
        .zip(attention)
        .fold(vec![0.0_f32; dim], |mut acc, (row, &weight)| {
            for (slot, &x) in acc.iter_mut().zip(row.iter()) {
                *slot += weight * x;
            }
            acc
        })
}
/// Attention-weighted second central moment of `embeddings` about their
/// weighted centroid: Σₖ aₖ·(eₖ − c)(eₖ − c)ᵀ.
///
/// The symmetric `dim × dim` result is returned in packed upper-triangular,
/// row-major order: (0,0), (0,1), …, (0,dim−1), (1,1), …, (dim−1,dim−1).
/// That is `dim·(dim+1)/2` entries, where `dim` is the first embedding's length.
/// The sum is not divided by the total weight, so it equals the weighted
/// covariance only when `attention` sums to 1. Returns an empty vector when
/// `embeddings` is empty.
///
/// # Panics
/// Panics if any embedding paired with a weight is shorter than the first
/// embedding. Extra components in longer embeddings are ignored.
pub fn weighted_covariance(embeddings: &[&[f32]], attention: &[f32]) -> Vec<f32> {
    if embeddings.is_empty() {
        return Vec::new();
    }
    let dim = embeddings[0].len();
    let centroid = weighted_centroid(embeddings, attention);
    let mut cov = vec![0.0_f32; dim * (dim + 1) / 2];
    // One embedding's deviation from the centroid. It is computed once per row
    // instead of once per (i, j) pair, and the buffer is reused across rows.
    let mut diff = vec![0.0_f32; dim];
    for (e, &a) in embeddings.iter().zip(attention.iter()) {
        // Slicing to `dim` keeps the original contract: short rows panic,
        // extra components are ignored.
        let e = &e[..dim];
        for ((d, &v), &c) in diff.iter_mut().zip(e).zip(&centroid) {
            *d = v - c;
        }
        // Row i of the packed upper triangle holds columns i..dim.
        let mut offset = 0;
        for (i, &di) in diff.iter().enumerate() {
            let tail = &diff[i..];
            // Same evaluation order as `a * di * dj`, so results are bit-identical.
            let scaled = a * di;
            for (slot, &dj) in cov[offset..offset + tail.len()].iter_mut().zip(tail) {
                *slot += scaled * dj;
            }
            offset += tail.len();
        }
    }
    cov
}
#[cfg(test)]
mod tests {
    //! Unit tests for conservation metrics, violations, configs, the tracker,
    //! and the weighted centroid/covariance helpers.
    use super::*;

    /// Three orthonormal basis vectors in R³.
    fn make_embeddings() -> Vec<Vec<f32>> {
        vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ]
    }

    #[test]
    fn test_compute_metrics() {
        let embeddings = make_embeddings();
        let refs: Vec<&[f32]> = embeddings.iter().map(|e| e.as_slice()).collect();
        let attention = vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0];
        let metrics = ConservationMetrics::compute(&refs, &attention);
        assert!((metrics.magnitude - 1.0).abs() < 1e-5);
        let expected_info = 3.0_f32.ln();
        assert!((metrics.information - expected_info).abs() < 1e-5);
    }

    #[test]
    fn test_empty_inputs() {
        let empty: Vec<&[f32]> = Vec::new();
        assert_eq!(ConservationMetrics::compute(&empty, &[]), ConservationMetrics::default());
        assert_eq!(ConservationMetrics::uniform(&empty), ConservationMetrics::default());
        assert!(weighted_centroid(&empty, &[]).is_empty());
        assert!(weighted_covariance(&empty, &[]).is_empty());
    }

    #[test]
    #[should_panic(expected = "Embeddings and attention must have same length")]
    fn test_compute_length_mismatch_panics() {
        let embeddings = make_embeddings();
        let refs: Vec<&[f32]> = embeddings.iter().map(|e| e.as_slice()).collect();
        let _ = ConservationMetrics::compute(&refs, &[1.0]);
    }

    #[test]
    fn test_uniform_matches_explicit_attention() {
        let embeddings = make_embeddings();
        let refs: Vec<&[f32]> = embeddings.iter().map(|e| e.as_slice()).collect();
        let uniform = ConservationMetrics::uniform(&refs);
        let explicit = ConservationMetrics::compute(&refs, &[1.0 / 3.0; 3]);
        assert_eq!(uniform, explicit);
    }

    #[test]
    fn test_is_conserved() {
        let embeddings = make_embeddings();
        let refs: Vec<&[f32]> = embeddings.iter().map(|e| e.as_slice()).collect();
        let attention = vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0];
        let m1 = ConservationMetrics::compute(&refs, &attention);
        let m2 = ConservationMetrics::compute(&refs, &attention);
        assert!(m1.is_conserved(&m2, 0.01));
    }

    #[test]
    fn test_conservation_violation() {
        let m1 = ConservationMetrics {
            magnitude: 1.0,
            energy: 0.5,
            information: 1.0,
        };
        let m2 = ConservationMetrics {
            magnitude: 1.1,
            energy: 0.6,
            information: 0.9,
        };
        let violation = m1.violation(&m2);
        assert!((violation.magnitude_delta - 0.1).abs() < 1e-5);
        assert!((violation.energy_delta - 0.1).abs() < 1e-5);
        assert!((violation.information_delta - 0.1).abs() < 1e-5);
    }

    #[test]
    fn test_violation_total_and_max() {
        let v = ConservationViolation {
            magnitude_delta: 0.1,
            energy_delta: 0.3,
            information_delta: 0.2,
        };
        assert!((v.total() - 0.6).abs() < 1e-5);
        assert_eq!(v.max(), 0.3);
        assert!(v.is_acceptable(0.31));
        assert!(!v.is_acceptable(0.3));
    }

    #[test]
    fn test_config_presets() {
        let v = ConservationViolation {
            magnitude_delta: 0.005,
            energy_delta: 0.005,
            information_delta: 0.005,
        };
        assert!(ConservationConfig::default().is_acceptable(&v));
        assert!(!ConservationConfig::strict().is_acceptable(&v));
    }

    #[test]
    fn test_max_entropy() {
        assert!((ConservationMetrics::max_entropy(1) - 0.0).abs() < 1e-5);
        assert!((ConservationMetrics::max_entropy(2) - 2.0_f32.ln()).abs() < 1e-5);
        assert!((ConservationMetrics::max_entropy(10) - 10.0_f32.ln()).abs() < 1e-5);
    }

    #[test]
    fn test_normalized_entropy() {
        let embeddings = make_embeddings();
        let refs: Vec<&[f32]> = embeddings.iter().map(|e| e.as_slice()).collect();
        let uniform_attention = vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0];
        let uniform_metrics = ConservationMetrics::compute(&refs, &uniform_attention);
        assert!((uniform_metrics.normalized_entropy(3) - 1.0).abs() < 1e-5);
        let concentrated = vec![0.9, 0.05, 0.05];
        let concentrated_metrics = ConservationMetrics::compute(&refs, &concentrated);
        assert!(concentrated_metrics.normalized_entropy(3) < 0.5);
    }

    #[test]
    fn test_tracker() {
        let config = ConservationConfig::default();
        let mut tracker = ConservationTracker::new(config);
        let m1 = ConservationMetrics {
            magnitude: 1.0,
            energy: 0.5,
            information: 1.0,
        };
        let m2 = ConservationMetrics {
            magnitude: 1.001,
            energy: 0.501,
            information: 1.01,
        };
        tracker.record(m1);
        tracker.record(m2);
        assert!(tracker.is_conserved_from_initial().unwrap());
        assert_eq!(tracker.history().len(), 2);
    }

    #[test]
    fn test_tracker_empty_drift_and_clear() {
        let mut tracker = ConservationTracker::new(ConservationConfig::strict());
        assert!(tracker.is_conserved_from_initial().is_none());
        assert!(tracker.total_drift().is_none());
        tracker.record(ConservationMetrics {
            magnitude: 1.0,
            energy: 0.5,
            information: 1.0,
        });
        tracker.record(ConservationMetrics {
            magnitude: 1.5,
            energy: 0.5,
            information: 1.0,
        });
        let drift = tracker.total_drift().unwrap();
        assert!((drift.magnitude_delta - 0.5).abs() < 1e-6);
        assert_eq!(tracker.is_conserved_from_initial(), Some(false));
        tracker.clear();
        assert!(tracker.history().is_empty());
        assert!(tracker.current().is_none());
    }

    #[test]
    fn test_weighted_centroid() {
        let embeddings = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
        ];
        let refs: Vec<&[f32]> = embeddings.iter().map(|e| e.as_slice()).collect();
        let attention = vec![0.5, 0.5];
        let centroid = weighted_centroid(&refs, &attention);
        assert!((centroid[0] - 0.5).abs() < 1e-5);
        assert!((centroid[1] - 0.5).abs() < 1e-5);
        let attention2 = vec![0.8, 0.2];
        let centroid2 = weighted_centroid(&refs, &attention2);
        assert!((centroid2[0] - 0.8).abs() < 1e-5);
        assert!((centroid2[1] - 0.2).abs() < 1e-5);
    }

    #[test]
    fn test_weighted_covariance_packed_layout() {
        let embeddings = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
        ];
        let refs: Vec<&[f32]> = embeddings.iter().map(|e| e.as_slice()).collect();
        let cov = weighted_covariance(&refs, &[0.5, 0.5]);
        // Packed upper triangle for dim = 2: [(0,0), (0,1), (1,1)].
        assert_eq!(cov.len(), 3);
        assert!((cov[0] - 0.25).abs() < 1e-6);
        assert!((cov[1] + 0.25).abs() < 1e-6);
        assert!((cov[2] - 0.25).abs() < 1e-6);
    }
}