rag_plusplus_core/trajectory/
conservation.rs

1//! Conservation Metrics for Bounded Forgetting
2//!
3//! Provides metrics to validate that memory operations preserve important
4//! properties. Inspired by RCP's conservation laws.
5//!
6//! # Conservation Laws
7//!
8//! | Law | Formula | Meaning |
9//! |-----|---------|---------|
10//! | Magnitude | ‖C‖ = const | Context doesn't disappear |
11//! | Energy | E = ½Σᵢⱼ aᵢaⱼ cos(eᵢ, eⱼ) | Attention capacity conserved |
12//! | Information | H = -Σ aᵢ log(aᵢ) | Shannon entropy of attention |
13//!
14//! # Usage
15//!
16//! ```
17//! use rag_plusplus_core::trajectory::ConservationMetrics;
18//!
19//! // Example embeddings (3 vectors of dimension 4)
20//! let embeddings_before: Vec<&[f32]> = vec![
21//!     &[1.0, 0.0, 0.0, 0.0],
22//!     &[0.0, 1.0, 0.0, 0.0],
23//!     &[0.0, 0.0, 1.0, 0.0],
24//! ];
25//! let attention_before = vec![0.5, 0.3, 0.2];
26//!
27//! let embeddings_after: Vec<&[f32]> = vec![
28//!     &[0.9, 0.1, 0.0, 0.0],
29//!     &[0.1, 0.9, 0.0, 0.0],
30//!     &[0.0, 0.0, 1.0, 0.0],
31//! ];
32//! let attention_after = vec![0.5, 0.3, 0.2];
33//!
34//! // Compute metrics before and after an operation
35//! let before = ConservationMetrics::compute(&embeddings_before, &attention_before);
36//! let after = ConservationMetrics::compute(&embeddings_after, &attention_after);
37//!
38//! // Check if conservation is preserved
39//! if before.is_conserved(&after, 0.1) {
40//!     println!("Operation preserved conservation laws");
41//! } else {
42//!     println!("Warning: Conservation violation detected");
43//! }
44//! ```
45
46use crate::distance::{cosine_similarity_fast, norm_fast};
47
48/// Conservation metrics for a set of embeddings with attention weights.
49#[derive(Debug, Clone, Copy, PartialEq)]
50pub struct ConservationMetrics {
51    /// Total weighted magnitude: Σᵢ aᵢ ‖eᵢ‖
52    pub magnitude: f32,
53    /// Attention energy: ½Σᵢⱼ aᵢaⱼ cos(eᵢ, eⱼ)
54    pub energy: f32,
55    /// Shannon entropy: -Σᵢ aᵢ log(aᵢ)
56    pub information: f32,
57}
58
59impl ConservationMetrics {
60    /// Compute conservation metrics for a set of embeddings and attention weights.
61    ///
62    /// # Arguments
63    ///
64    /// * `embeddings` - Slice of embedding slices
65    /// * `attention` - Attention weights for each embedding (must sum to 1)
66    ///
67    /// # Returns
68    ///
69    /// ConservationMetrics with magnitude, energy, and information.
70    pub fn compute(embeddings: &[&[f32]], attention: &[f32]) -> Self {
71        assert_eq!(embeddings.len(), attention.len(), "Embeddings and attention must have same length");
72
73        if embeddings.is_empty() {
74            return Self {
75                magnitude: 0.0,
76                energy: 0.0,
77                information: 0.0,
78            };
79        }
80
81        // Magnitude: Σᵢ aᵢ ‖eᵢ‖
82        let magnitude: f32 = embeddings.iter()
83            .zip(attention.iter())
84            .map(|(e, a)| a * norm_fast(e))
85            .sum();
86
87        // Energy: ½Σᵢⱼ aᵢaⱼ cos(eᵢ, eⱼ)
88        let mut energy = 0.0_f32;
89        for (i, ei) in embeddings.iter().enumerate() {
90            for (j, ej) in embeddings.iter().enumerate() {
91                energy += attention[i] * attention[j] * cosine_similarity_fast(ei, ej);
92            }
93        }
94        energy *= 0.5;
95
96        // Information: -Σᵢ aᵢ log(aᵢ)
97        let information: f32 = -attention.iter()
98            .filter(|&&a| a > 1e-10)
99            .map(|a| a * a.ln())
100            .sum::<f32>();
101
102        Self {
103            magnitude,
104            energy,
105            information,
106        }
107    }
108
109    /// Compute metrics from owned vectors.
110    pub fn from_vecs(embeddings: &[Vec<f32>], attention: &[f32]) -> Self {
111        let refs: Vec<&[f32]> = embeddings.iter().map(|e| e.as_slice()).collect();
112        Self::compute(&refs, attention)
113    }
114
115    /// Check if conservation is preserved within tolerance.
116    ///
117    /// # Arguments
118    ///
119    /// * `other` - Metrics after an operation
120    /// * `tolerance` - Maximum allowed difference
121    ///
122    /// # Returns
123    ///
124    /// True if magnitude and energy are conserved within tolerance.
125    #[inline]
126    pub fn is_conserved(&self, other: &Self, tolerance: f32) -> bool {
127        (self.magnitude - other.magnitude).abs() < tolerance
128            && (self.energy - other.energy).abs() < tolerance
129    }
130
131    /// Check if all three metrics are conserved.
132    #[inline]
133    pub fn is_fully_conserved(&self, other: &Self, tolerance: f32) -> bool {
134        self.is_conserved(other, tolerance)
135            && (self.information - other.information).abs() < tolerance
136    }
137
138    /// Compute the conservation violation (distance from conservation).
139    pub fn violation(&self, other: &Self) -> ConservationViolation {
140        ConservationViolation {
141            magnitude_delta: (self.magnitude - other.magnitude).abs(),
142            energy_delta: (self.energy - other.energy).abs(),
143            information_delta: (self.information - other.information).abs(),
144        }
145    }
146
147    /// Create metrics for a uniform attention distribution.
148    pub fn uniform(embeddings: &[&[f32]]) -> Self {
149        if embeddings.is_empty() {
150            return Self {
151                magnitude: 0.0,
152                energy: 0.0,
153                information: 0.0,
154            };
155        }
156
157        let n = embeddings.len();
158        let attention: Vec<f32> = vec![1.0 / n as f32; n];
159        Self::compute(embeddings, &attention)
160    }
161
162    /// Maximum possible entropy for n items.
163    #[inline]
164    pub fn max_entropy(n: usize) -> f32 {
165        if n <= 1 {
166            0.0
167        } else {
168            (n as f32).ln()
169        }
170    }
171
172    /// Normalized entropy (0 = concentrated, 1 = uniform).
173    #[inline]
174    pub fn normalized_entropy(&self, n: usize) -> f32 {
175        let max = Self::max_entropy(n);
176        if max > 0.0 {
177            self.information / max
178        } else {
179            0.0
180        }
181    }
182}
183
184impl Default for ConservationMetrics {
185    fn default() -> Self {
186        Self {
187            magnitude: 0.0,
188            energy: 0.0,
189            information: 0.0,
190        }
191    }
192}
193
/// Per-metric breakdown of how far two metric snapshots differ.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ConservationViolation {
    /// Absolute change in the magnitude metric
    pub magnitude_delta: f32,
    /// Absolute change in the energy metric
    pub energy_delta: f32,
    /// Absolute change in the information metric
    pub information_delta: f32,
}

impl ConservationViolation {
    /// Sum of the three deltas (L1 norm of the violation).
    #[inline]
    pub fn total(&self) -> f32 {
        let Self {
            magnitude_delta,
            energy_delta,
            information_delta,
        } = *self;
        magnitude_delta + energy_delta + information_delta
    }

    /// Largest of the three deltas.
    #[inline]
    pub fn max(&self) -> f32 {
        let larger_of_first_two = f32::max(self.magnitude_delta, self.energy_delta);
        f32::max(larger_of_first_two, self.information_delta)
    }

    /// Whether the largest delta is strictly below `tolerance`.
    #[inline]
    pub fn is_acceptable(&self, tolerance: f32) -> bool {
        let worst = self.max();
        worst < tolerance
    }
}
226
227/// Configuration for conservation checking.
228#[derive(Debug, Clone)]
229pub struct ConservationConfig {
230    /// Tolerance for magnitude conservation
231    pub magnitude_tolerance: f32,
232    /// Tolerance for energy conservation
233    pub energy_tolerance: f32,
234    /// Tolerance for information conservation
235    pub information_tolerance: f32,
236    /// Whether to enforce conservation (fail on violation)
237    pub strict: bool,
238}
239
240impl Default for ConservationConfig {
241    fn default() -> Self {
242        Self {
243            magnitude_tolerance: 0.01,
244            energy_tolerance: 0.01,
245            information_tolerance: 0.1, // Information can vary more
246            strict: false,
247        }
248    }
249}
250
251impl ConservationConfig {
252    /// Create a strict configuration.
253    pub fn strict() -> Self {
254        Self {
255            magnitude_tolerance: 0.001,
256            energy_tolerance: 0.001,
257            information_tolerance: 0.01,
258            strict: true,
259        }
260    }
261
262    /// Check if violation is acceptable according to config.
263    pub fn is_acceptable(&self, violation: &ConservationViolation) -> bool {
264        violation.magnitude_delta < self.magnitude_tolerance
265            && violation.energy_delta < self.energy_tolerance
266            && violation.information_delta < self.information_tolerance
267    }
268}
269
270/// Track conservation over time.
271#[derive(Debug, Clone)]
272pub struct ConservationTracker {
273    history: Vec<ConservationMetrics>,
274    config: ConservationConfig,
275}
276
277impl ConservationTracker {
278    /// Create a new tracker.
279    pub fn new(config: ConservationConfig) -> Self {
280        Self {
281            history: Vec::new(),
282            config,
283        }
284    }
285
286    /// Record a conservation snapshot.
287    pub fn record(&mut self, metrics: ConservationMetrics) {
288        self.history.push(metrics);
289    }
290
291    /// Get the most recent metrics.
292    pub fn current(&self) -> Option<&ConservationMetrics> {
293        self.history.last()
294    }
295
296    /// Get the initial metrics.
297    pub fn initial(&self) -> Option<&ConservationMetrics> {
298        self.history.first()
299    }
300
301    /// Check if conservation is maintained from initial state.
302    pub fn is_conserved_from_initial(&self) -> Option<bool> {
303        let initial = self.initial()?;
304        let current = self.current()?;
305        Some(self.config.is_acceptable(&initial.violation(current)))
306    }
307
308    /// Get the total drift from initial state.
309    pub fn total_drift(&self) -> Option<ConservationViolation> {
310        let initial = self.initial()?;
311        let current = self.current()?;
312        Some(initial.violation(current))
313    }
314
315    /// Get all recorded metrics.
316    pub fn history(&self) -> &[ConservationMetrics] {
317        &self.history
318    }
319
320    /// Clear history.
321    pub fn clear(&mut self) {
322        self.history.clear();
323    }
324}
325
/// Attention-weighted sum of the embeddings (their "center of mass").
///
/// The output has the dimension of the first embedding; an empty input gives
/// an empty vector. Embeddings and weights are paired positionally, so any
/// surplus entries on either side are ignored. Likewise, components of an
/// embedding beyond the output dimension are ignored, and a shorter embedding
/// only contributes to the leading components it has.
pub fn weighted_centroid(embeddings: &[&[f32]], attention: &[f32]) -> Vec<f32> {
    let dim = match embeddings.first() {
        Some(first) => first.len(),
        None => return Vec::new(),
    };

    let mut centroid = vec![0.0_f32; dim];

    embeddings
        .iter()
        .zip(attention)
        .for_each(|(vector, &weight)| {
            centroid
                .iter_mut()
                .zip(vector.iter())
                .for_each(|(acc, &component)| *acc += weight * component);
        });

    centroid
}
345
346/// Compute the attention-weighted covariance matrix (flattened).
347///
348/// Returns the upper triangle of the covariance matrix as a flat vector.
349pub fn weighted_covariance(embeddings: &[&[f32]], attention: &[f32]) -> Vec<f32> {
350    if embeddings.is_empty() {
351        return Vec::new();
352    }
353
354    let dim = embeddings[0].len();
355    let centroid = weighted_centroid(embeddings, attention);
356
357    // Compute upper triangle of covariance
358    let n_cov = (dim * (dim + 1)) / 2;
359    let mut cov = vec![0.0_f32; n_cov];
360
361    for (e, &a) in embeddings.iter().zip(attention.iter()) {
362        let mut idx = 0;
363        for i in 0..dim {
364            for j in i..dim {
365                let diff_i = e[i] - centroid[i];
366                let diff_j = e[j] - centroid[j];
367                cov[idx] += a * diff_i * diff_j;
368                idx += 1;
369            }
370        }
371    }
372
373    cov
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance shared by the approximate float comparisons below.
    const EPS: f32 = 1e-5;

    /// True when `actual` is within `EPS` of `expected`.
    fn approx(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Borrow each owned embedding as a slice.
    fn as_refs(owned: &[Vec<f32>]) -> Vec<&[f32]> {
        owned.iter().map(|e| e.as_slice()).collect()
    }

    /// Uniform attention over three items.
    fn thirds() -> Vec<f32> {
        vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]
    }

    /// The three standard basis vectors of dimension 3.
    fn make_embeddings() -> Vec<Vec<f32>> {
        (0..3)
            .map(|hot| (0..3).map(|k| if k == hot { 1.0 } else { 0.0 }).collect())
            .collect()
    }

    #[test]
    fn test_compute_metrics() {
        let embeddings = make_embeddings();
        let refs = as_refs(&embeddings);
        let attention = thirds();

        let metrics = ConservationMetrics::compute(&refs, &attention);

        // Magnitude: three unit vectors, each weighted 1/3, sum to 1.0
        assert!(approx(metrics.magnitude, 1.0));

        // Information: -3 * (1/3) * ln(1/3) = ln(3)
        assert!(approx(metrics.information, 3.0_f32.ln()));
    }

    #[test]
    fn test_is_conserved() {
        let embeddings = make_embeddings();
        let refs = as_refs(&embeddings);
        let attention = thirds();

        let first = ConservationMetrics::compute(&refs, &attention);
        let second = ConservationMetrics::compute(&refs, &attention);

        assert!(first.is_conserved(&second, 0.01));
    }

    #[test]
    fn test_conservation_violation() {
        let before = ConservationMetrics {
            magnitude: 1.0,
            energy: 0.5,
            information: 1.0,
        };
        let after = ConservationMetrics {
            magnitude: 1.1,
            energy: 0.6,
            information: 0.9,
        };

        // Every metric moved by 0.1, in either direction.
        let violation = before.violation(&after);
        assert!(approx(violation.magnitude_delta, 0.1));
        assert!(approx(violation.energy_delta, 0.1));
        assert!(approx(violation.information_delta, 0.1));
    }

    #[test]
    fn test_max_entropy() {
        // ln(1) = 0
        assert!(approx(ConservationMetrics::max_entropy(1), 0.0));

        // ln(2)
        assert!(approx(ConservationMetrics::max_entropy(2), 2.0_f32.ln()));

        // ln(10)
        assert!(approx(ConservationMetrics::max_entropy(10), 10.0_f32.ln()));
    }

    #[test]
    fn test_normalized_entropy() {
        let embeddings = make_embeddings();
        let refs = as_refs(&embeddings);

        // A uniform distribution reaches the maximum, so it normalizes to 1.
        let uniform_metrics = ConservationMetrics::compute(&refs, &thirds());
        assert!(approx(uniform_metrics.normalized_entropy(3), 1.0));

        // A concentrated distribution sits well below the maximum.
        let concentrated: Vec<f32> = vec![0.9, 0.05, 0.05];
        let concentrated_metrics = ConservationMetrics::compute(&refs, &concentrated);
        assert!(concentrated_metrics.normalized_entropy(3) < 0.5);
    }

    #[test]
    fn test_tracker() {
        let mut tracker = ConservationTracker::new(ConservationConfig::default());

        let start = ConservationMetrics {
            magnitude: 1.0,
            energy: 0.5,
            information: 1.0,
        };
        // Small drift, inside the default tolerances.
        let drifted = ConservationMetrics {
            magnitude: 1.001,
            energy: 0.501,
            information: 1.01,
        };

        tracker.record(start);
        tracker.record(drifted);

        assert!(tracker.is_conserved_from_initial().unwrap());
        assert_eq!(tracker.history().len(), 2);
    }

    #[test]
    fn test_weighted_centroid() {
        let embeddings: Vec<Vec<f32>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let refs = as_refs(&embeddings);

        // Equal weights
        let even: Vec<f32> = vec![0.5, 0.5];
        let centroid = weighted_centroid(&refs, &even);
        assert!(approx(centroid[0], 0.5));
        assert!(approx(centroid[1], 0.5));

        // Unequal weights
        let skewed: Vec<f32> = vec![0.8, 0.2];
        let centroid2 = weighted_centroid(&refs, &skewed);
        assert!(approx(centroid2[0], 0.8));
        assert!(approx(centroid2[1], 0.2));
    }
}