// rag_plusplus_core/trajectory/conservation.rs
use crate::distance::{cosine_similarity_fast, norm_fast};
47
/// Three attention-weighted summary quantities for a set of embeddings.
///
/// Values are normally produced by `ConservationMetrics::compute`; the fields
/// are public so snapshots can also be built or inspected directly.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ConservationMetrics {
    /// Sum over all items of `attention * norm_fast(embedding)`.
    pub magnitude: f32,
    /// Half of the sum, over every ordered pair `(i, j)`, of
    /// `attention[i] * attention[j] * cosine_similarity_fast(e_i, e_j)`.
    pub energy: f32,
    /// Natural-log entropy of the attention weights, `-Σ a·ln(a)`, taken over
    /// the weights that are greater than `1e-10`.
    pub information: f32,
}
58
59impl ConservationMetrics {
60 pub fn compute(embeddings: &[&[f32]], attention: &[f32]) -> Self {
71 assert_eq!(embeddings.len(), attention.len(), "Embeddings and attention must have same length");
72
73 if embeddings.is_empty() {
74 return Self {
75 magnitude: 0.0,
76 energy: 0.0,
77 information: 0.0,
78 };
79 }
80
81 let magnitude: f32 = embeddings.iter()
83 .zip(attention.iter())
84 .map(|(e, a)| a * norm_fast(e))
85 .sum();
86
87 let mut energy = 0.0_f32;
89 for (i, ei) in embeddings.iter().enumerate() {
90 for (j, ej) in embeddings.iter().enumerate() {
91 energy += attention[i] * attention[j] * cosine_similarity_fast(ei, ej);
92 }
93 }
94 energy *= 0.5;
95
96 let information: f32 = -attention.iter()
98 .filter(|&&a| a > 1e-10)
99 .map(|a| a * a.ln())
100 .sum::<f32>();
101
102 Self {
103 magnitude,
104 energy,
105 information,
106 }
107 }
108
109 pub fn from_vecs(embeddings: &[Vec<f32>], attention: &[f32]) -> Self {
111 let refs: Vec<&[f32]> = embeddings.iter().map(|e| e.as_slice()).collect();
112 Self::compute(&refs, attention)
113 }
114
115 #[inline]
126 pub fn is_conserved(&self, other: &Self, tolerance: f32) -> bool {
127 (self.magnitude - other.magnitude).abs() < tolerance
128 && (self.energy - other.energy).abs() < tolerance
129 }
130
131 #[inline]
133 pub fn is_fully_conserved(&self, other: &Self, tolerance: f32) -> bool {
134 self.is_conserved(other, tolerance)
135 && (self.information - other.information).abs() < tolerance
136 }
137
138 pub fn violation(&self, other: &Self) -> ConservationViolation {
140 ConservationViolation {
141 magnitude_delta: (self.magnitude - other.magnitude).abs(),
142 energy_delta: (self.energy - other.energy).abs(),
143 information_delta: (self.information - other.information).abs(),
144 }
145 }
146
147 pub fn uniform(embeddings: &[&[f32]]) -> Self {
149 if embeddings.is_empty() {
150 return Self {
151 magnitude: 0.0,
152 energy: 0.0,
153 information: 0.0,
154 };
155 }
156
157 let n = embeddings.len();
158 let attention: Vec<f32> = vec![1.0 / n as f32; n];
159 Self::compute(embeddings, &attention)
160 }
161
162 #[inline]
164 pub fn max_entropy(n: usize) -> f32 {
165 if n <= 1 {
166 0.0
167 } else {
168 (n as f32).ln()
169 }
170 }
171
172 #[inline]
174 pub fn normalized_entropy(&self, n: usize) -> f32 {
175 let max = Self::max_entropy(n);
176 if max > 0.0 {
177 self.information / max
178 } else {
179 0.0
180 }
181 }
182}
183
184impl Default for ConservationMetrics {
185 fn default() -> Self {
186 Self {
187 magnitude: 0.0,
188 energy: 0.0,
189 information: 0.0,
190 }
191 }
192}
193
/// Absolute per-metric differences between two metric snapshots.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ConservationViolation {
    /// Absolute difference in `magnitude`.
    pub magnitude_delta: f32,
    /// Absolute difference in `energy`.
    pub energy_delta: f32,
    /// Absolute difference in `information`.
    pub information_delta: f32,
}

impl ConservationViolation {
    /// Sum of the three deltas.
    #[inline]
    pub fn total(&self) -> f32 {
        self.magnitude_delta + self.energy_delta + self.information_delta
    }

    /// Largest of the three deltas.
    ///
    /// This follows `f32::max`, which skips NaN operands: the result is NaN
    /// only when all three deltas are NaN.
    #[inline]
    pub fn max(&self) -> f32 {
        self.magnitude_delta
            .max(self.energy_delta)
            .max(self.information_delta)
    }

    /// Returns `true` when every delta is strictly below `tolerance`.
    ///
    /// Each delta is compared on its own rather than through [`max`](Self::max),
    /// because `max` drops NaN operands and would let a violation with a NaN
    /// component pass. A NaN delta (or a NaN `tolerance`) always yields `false`.
    #[inline]
    pub fn is_acceptable(&self, tolerance: f32) -> bool {
        [
            self.magnitude_delta,
            self.energy_delta,
            self.information_delta,
        ]
        .iter()
        .all(|&delta| delta < tolerance)
    }
}
226
/// Per-metric tolerances used to judge a [`ConservationViolation`].
#[derive(Debug, Clone)]
pub struct ConservationConfig {
    /// Exclusive upper bound for an acceptable `magnitude_delta`.
    pub magnitude_tolerance: f32,
    /// Exclusive upper bound for an acceptable `energy_delta`.
    pub energy_tolerance: f32,
    /// Exclusive upper bound for an acceptable `information_delta`.
    pub information_tolerance: f32,
    /// Strictness flag. NOTE(review): nothing visible in this file reads it;
    /// confirm its intended effect against the callers.
    pub strict: bool,
}

impl Default for ConservationConfig {
    /// Lenient preset: `0.01` for magnitude and energy, `0.1` for
    /// information, with `strict` switched off.
    fn default() -> Self {
        let geometric_tolerance = 0.01;
        Self {
            magnitude_tolerance: geometric_tolerance,
            energy_tolerance: geometric_tolerance,
            information_tolerance: 0.1,
            strict: false,
        }
    }
}
250
251impl ConservationConfig {
252 pub fn strict() -> Self {
254 Self {
255 magnitude_tolerance: 0.001,
256 energy_tolerance: 0.001,
257 information_tolerance: 0.01,
258 strict: true,
259 }
260 }
261
262 pub fn is_acceptable(&self, violation: &ConservationViolation) -> bool {
264 violation.magnitude_delta < self.magnitude_tolerance
265 && violation.energy_delta < self.energy_tolerance
266 && violation.information_delta < self.information_tolerance
267 }
268}
269
270#[derive(Debug, Clone)]
272pub struct ConservationTracker {
273 history: Vec<ConservationMetrics>,
274 config: ConservationConfig,
275}
276
277impl ConservationTracker {
278 pub fn new(config: ConservationConfig) -> Self {
280 Self {
281 history: Vec::new(),
282 config,
283 }
284 }
285
286 pub fn record(&mut self, metrics: ConservationMetrics) {
288 self.history.push(metrics);
289 }
290
291 pub fn current(&self) -> Option<&ConservationMetrics> {
293 self.history.last()
294 }
295
296 pub fn initial(&self) -> Option<&ConservationMetrics> {
298 self.history.first()
299 }
300
301 pub fn is_conserved_from_initial(&self) -> Option<bool> {
303 let initial = self.initial()?;
304 let current = self.current()?;
305 Some(self.config.is_acceptable(&initial.violation(current)))
306 }
307
308 pub fn total_drift(&self) -> Option<ConservationViolation> {
310 let initial = self.initial()?;
311 let current = self.current()?;
312 Some(initial.violation(current))
313 }
314
315 pub fn history(&self) -> &[ConservationMetrics] {
317 &self.history
318 }
319
320 pub fn clear(&mut self) {
322 self.history.clear();
323 }
324}
325
/// Attention-weighted sum of `embeddings`, component by component.
///
/// The output has the dimension of the first embedding. Weights are used as
/// given (no normalisation). Pairing is positional and stops at the shorter
/// of `embeddings` and `attention`; likewise, components beyond the shorter
/// of an embedding and the output are ignored.
///
/// Returns an empty vector when `embeddings` is empty.
pub fn weighted_centroid(embeddings: &[&[f32]], attention: &[f32]) -> Vec<f32> {
    let dim = match embeddings.first() {
        Some(first) => first.len(),
        None => return Vec::new(),
    };

    let mut accumulator = vec![0.0_f32; dim];
    embeddings
        .iter()
        .zip(attention)
        .for_each(|(vector, &weight)| {
            accumulator
                .iter_mut()
                .zip(vector.iter())
                .for_each(|(slot, &component)| *slot += weight * component);
        });

    accumulator
}
345
346pub fn weighted_covariance(embeddings: &[&[f32]], attention: &[f32]) -> Vec<f32> {
350 if embeddings.is_empty() {
351 return Vec::new();
352 }
353
354 let dim = embeddings[0].len();
355 let centroid = weighted_centroid(embeddings, attention);
356
357 let n_cov = (dim * (dim + 1)) / 2;
359 let mut cov = vec![0.0_f32; n_cov];
360
361 for (e, &a) in embeddings.iter().zip(attention.iter()) {
362 let mut idx = 0;
363 for i in 0..dim {
364 for j in i..dim {
365 let diff_i = e[i] - centroid[i];
366 let diff_j = e[j] - centroid[j];
367 cov[idx] += a * diff_i * diff_j;
368 idx += 1;
369 }
370 }
371 }
372
373 cov
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by the approximate float comparisons.
    const EPS: f32 = 1e-5;

    /// The three standard basis vectors of 3-D space.
    fn make_embeddings() -> Vec<Vec<f32>> {
        vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ]
    }

    /// Borrows each owned row as a slice.
    fn borrow_rows(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    /// `true` when `actual` is within `EPS` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn test_compute_metrics() {
        let embeddings = make_embeddings();
        let refs = borrow_rows(&embeddings);
        let attention = vec![1.0_f32 / 3.0; 3];

        let metrics = ConservationMetrics::compute(&refs, &attention);

        // Unit-norm rows with weights summing to one.
        assert!(close(metrics.magnitude, 1.0));

        // Uniform weights over three items give entropy ln(3).
        let expected_info = 3.0_f32.ln();
        assert!(close(metrics.information, expected_info));
    }

    #[test]
    fn test_is_conserved() {
        let embeddings = make_embeddings();
        let refs = borrow_rows(&embeddings);
        let attention = vec![1.0_f32 / 3.0; 3];

        let first = ConservationMetrics::compute(&refs, &attention);
        let second = ConservationMetrics::compute(&refs, &attention);

        assert!(first.is_conserved(&second, 0.01));
    }

    #[test]
    fn test_conservation_violation() {
        let before = ConservationMetrics {
            magnitude: 1.0,
            energy: 0.5,
            information: 1.0,
        };
        let after = ConservationMetrics {
            magnitude: 1.1,
            energy: 0.6,
            information: 0.9,
        };

        let violation = before.violation(&after);

        assert!(close(violation.magnitude_delta, 0.1));
        assert!(close(violation.energy_delta, 0.1));
        assert!(close(violation.information_delta, 0.1));
    }

    #[test]
    fn test_max_entropy() {
        assert!(close(ConservationMetrics::max_entropy(1), 0.0));
        assert!(close(ConservationMetrics::max_entropy(2), 2.0_f32.ln()));
        assert!(close(ConservationMetrics::max_entropy(10), 10.0_f32.ln()));
    }

    #[test]
    fn test_normalized_entropy() {
        let embeddings = make_embeddings();
        let refs = borrow_rows(&embeddings);

        // Uniform weights reach the maximum, so the normalised value is 1.
        let uniform_attention = vec![1.0_f32 / 3.0; 3];
        let uniform_metrics = ConservationMetrics::compute(&refs, &uniform_attention);
        assert!(close(uniform_metrics.normalized_entropy(3), 1.0));

        // Weights concentrated on one item fall well below the maximum.
        let concentrated = vec![0.9, 0.05, 0.05];
        let concentrated_metrics = ConservationMetrics::compute(&refs, &concentrated);
        assert!(concentrated_metrics.normalized_entropy(3) < 0.5);
    }

    #[test]
    fn test_tracker() {
        let mut tracker = ConservationTracker::new(ConservationConfig::default());

        let first = ConservationMetrics {
            magnitude: 1.0,
            energy: 0.5,
            information: 1.0,
        };
        let second = ConservationMetrics {
            magnitude: 1.001,
            energy: 0.501,
            information: 1.01,
        };

        tracker.record(first);
        tracker.record(second);

        assert_eq!(tracker.is_conserved_from_initial(), Some(true));
        assert_eq!(tracker.history().len(), 2);
    }

    #[test]
    fn test_weighted_centroid() {
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let refs = borrow_rows(&embeddings);

        let even = weighted_centroid(&refs, &[0.5, 0.5]);
        assert!(close(even[0], 0.5));
        assert!(close(even[1], 0.5));

        let skewed = weighted_centroid(&refs, &[0.8, 0.2]);
        assert!(close(skewed[0], 0.8));
        assert!(close(skewed[1], 0.2));
    }
}