1use crate::distance::cosine_similarity_fast;
51use crate::trajectory::{TrajectoryCoordinate5D, DLMWeights};
52
53#[derive(Debug, Clone)]
55pub struct IRCPConfig {
56 pub temperature: f32,
58
59 pub coord_weights: DLMWeights,
61
62 pub spatial_weight: f32,
65
66 pub use_coordinate_cosine: bool,
68
69 pub min_attention: f32,
71
72 pub causal_mask: bool,
74}
75
76impl Default for IRCPConfig {
77 fn default() -> Self {
78 Self {
79 temperature: 1.0,
80 coord_weights: DLMWeights::default(),
81 spatial_weight: 0.3, use_coordinate_cosine: false,
83 min_attention: 1e-10,
84 causal_mask: false,
85 }
86 }
87}
88
89impl IRCPConfig {
90 pub fn semantic_focused() -> Self {
92 Self {
93 spatial_weight: 0.1,
94 coord_weights: DLMWeights::semantic_focused(),
95 ..Default::default()
96 }
97 }
98
99 pub fn spatial_focused() -> Self {
101 Self {
102 spatial_weight: 0.7,
103 coord_weights: DLMWeights::structural_focused(),
104 ..Default::default()
105 }
106 }
107
108 pub fn causal() -> Self {
110 Self {
111 causal_mask: true,
112 ..Default::default()
113 }
114 }
115
116 pub fn sharp() -> Self {
118 Self {
119 temperature: 0.1,
120 ..Default::default()
121 }
122 }
123
124 pub fn diffuse() -> Self {
126 Self {
127 temperature: 3.0,
128 ..Default::default()
129 }
130 }
131}
132
/// Attention distributions computed for one query over a list of context entries.
///
/// As produced by this module, every vector has one element per context entry, in
/// context order.
#[derive(Debug, Clone)]
pub struct AttentionWeights {
    /// Softmax-normalized attention from the query to each context entry.
    pub forward: Vec<f32>,

    /// Forward attention re-weighted by a per-context influence and renormalized.
    pub inverse: Vec<f32>,

    /// Cross-role attention weights (a copy of `forward` when no role information
    /// was supplied).
    pub cross: Vec<f32>,

    /// Blended scores before softmax normalization.
    pub raw_scores: Vec<f32>,

    /// Sum of the `forward` weights.
    pub total_mass: f32,
}

impl AttentionWeights {
    /// Weights for an empty context: all vectors empty and zero total mass.
    pub fn empty() -> Self {
        Self {
            forward: Vec::new(),
            inverse: Vec::new(),
            cross: Vec::new(),
            raw_scores: Vec::new(),
            total_mass: 0.0,
        }
    }

    /// Uniform weights over `n` context entries (`1 / n` each, raw scores of `1.0`).
    ///
    /// Returns [`AttentionWeights::empty`] when `n` is zero.
    pub fn uniform(n: usize) -> Self {
        if n == 0 {
            return Self::empty();
        }

        let weight = 1.0 / n as f32;
        Self {
            forward: vec![weight; n],
            inverse: vec![weight; n],
            cross: vec![weight; n],
            raw_scores: vec![1.0; n],
            total_mass: 1.0,
        }
    }

    /// Orders two weights so that the larger one sorts first and NaN sorts last.
    ///
    /// Unlike `partial_cmp(..).unwrap_or(Equal)`, this is a consistent ordering even
    /// when NaN is present, which `sort_by` requires to avoid unspecified results.
    fn rank_descending(a: f32, b: f32) -> std::cmp::Ordering {
        match (a.is_nan(), b.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            // Neither value is NaN here, so the comparison is always defined.
            (false, false) => b.partial_cmp(&a).unwrap_or(std::cmp::Ordering::Equal),
        }
    }

    /// Index of the largest forward weight, or `None` when there are no weights.
    ///
    /// Ties resolve to the lowest index, matching the first element of
    /// [`sorted_forward_indices`](Self::sorted_forward_indices). NaN weights are only
    /// selected when every weight is NaN.
    pub fn top_forward(&self) -> Option<usize> {
        self.forward
            .iter()
            .enumerate()
            // `min_by` keeps the first of several equal elements, so ties go to the
            // lowest index under the descending ranking.
            .min_by(|a, b| Self::rank_descending(*a.1, *b.1))
            .map(|(i, _)| i)
    }

    /// Context indices ordered by descending forward weight.
    ///
    /// Equal weights keep ascending index order (the sort is stable); NaN weights
    /// are placed last.
    pub fn sorted_forward_indices(&self) -> Vec<usize> {
        let mut indices: Vec<usize> = (0..self.forward.len()).collect();
        indices.sort_by(|&a, &b| Self::rank_descending(self.forward[a], self.forward[b]));
        indices
    }

    /// The first `k` indices of [`sorted_forward_indices`](Self::sorted_forward_indices);
    /// fewer are returned when there are fewer than `k` weights.
    pub fn top_k_forward(&self, k: usize) -> Vec<usize> {
        let mut ranked = self.sorted_forward_indices();
        ranked.truncate(k);
        ranked
    }

    /// Shannon entropy (natural log) of the forward weights.
    ///
    /// Weights at or below `1e-10` are skipped so that `ln` is never taken of zero.
    pub fn forward_entropy(&self) -> f32 {
        -self
            .forward
            .iter()
            .filter(|&&w| w > 1e-10)
            .map(|w| w * w.ln())
            .sum::<f32>()
    }

    /// Returns `true` when the forward entropy is strictly below `threshold`.
    pub fn is_concentrated(&self, threshold: f32) -> bool {
        self.forward_entropy() < threshold
    }
}
220
221#[derive(Debug, Clone)]
226pub struct IRCPPropagator {
227 config: IRCPConfig,
228}
229
230impl IRCPPropagator {
231 pub fn new(config: IRCPConfig) -> Self {
233 Self { config }
234 }
235
236 #[inline]
240 fn spatial_weight(&self, query: &TrajectoryCoordinate5D, context: &TrajectoryCoordinate5D) -> f32 {
241 if self.config.use_coordinate_cosine {
242 (1.0 + query.cosine_similarity(context)) / 2.0
244 } else {
245 let dist = query.dlm_distance(context, &self.config.coord_weights);
247 (-dist).exp()
248 }
249 }
250
251 #[inline]
255 fn semantic_weight(&self, query_emb: &[f32], context_emb: &[f32]) -> f32 {
256 (1.0 + cosine_similarity_fast(query_emb, context_emb)) / 2.0
258 }
259
260 fn softmax(&self, scores: &[f32]) -> Vec<f32> {
262 if scores.is_empty() {
263 return Vec::new();
264 }
265
266 let max_score = scores
268 .iter()
269 .copied()
270 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
271 .unwrap_or(0.0);
272
273 let exps: Vec<f32> = scores
275 .iter()
276 .map(|&s| ((s - max_score) / self.config.temperature).exp())
277 .collect();
278
279 let sum: f32 = exps.iter().sum();
281 if sum > 0.0 {
282 exps.iter().map(|e| (e / sum).max(self.config.min_attention)).collect()
283 } else {
284 vec![1.0 / scores.len() as f32; scores.len()]
285 }
286 }
287
288 pub fn compute_forward_attention(
292 &self,
293 query_coord: &TrajectoryCoordinate5D,
294 context_coords: &[TrajectoryCoordinate5D],
295 query_emb: &[f32],
296 context_embs: &[&[f32]],
297 ) -> (Vec<f32>, Vec<f32>) {
298 assert_eq!(
299 context_coords.len(),
300 context_embs.len(),
301 "Coordinate and embedding counts must match"
302 );
303
304 if context_coords.is_empty() {
305 return (Vec::new(), Vec::new());
306 }
307
308 let sw = self.config.spatial_weight;
309 let raw_scores: Vec<f32> = context_coords
310 .iter()
311 .zip(context_embs.iter())
312 .enumerate()
313 .map(|(_i, (coord, emb))| {
314 if self.config.causal_mask && coord.temporal > query_coord.temporal {
316 return 0.0;
317 }
318
319 let spatial = self.spatial_weight(query_coord, coord);
320 let semantic = self.semantic_weight(query_emb, emb);
321
322 sw * spatial + (1.0 - sw) * semantic
324 })
325 .collect();
326
327 let attention = self.softmax(&raw_scores);
328 (attention, raw_scores)
329 }
330
331 pub fn compute_inverse_attention(
339 &self,
340 forward_attention: &[f32],
341 influences: &[f32],
342 ) -> Vec<f32> {
343 assert_eq!(
344 forward_attention.len(),
345 influences.len(),
346 "Attention and influence counts must match"
347 );
348
349 if forward_attention.is_empty() {
350 return Vec::new();
351 }
352
353 let weighted: Vec<f32> = forward_attention
355 .iter()
356 .zip(influences.iter())
357 .map(|(&a, &inf)| a * inf)
358 .collect();
359
360 let sum: f32 = weighted.iter().sum();
362 if sum > 0.0 {
363 weighted.iter().map(|w| w / sum).collect()
364 } else {
365 vec![1.0 / weighted.len() as f32; weighted.len()]
366 }
367 }
368
369 pub fn compute_cross_attention(
374 &self,
375 query_coord: &TrajectoryCoordinate5D,
376 context_coords: &[TrajectoryCoordinate5D],
377 query_emb: &[f32],
378 context_embs: &[&[f32]],
379 query_is_user: bool,
380 context_is_user: &[bool],
381 ) -> Vec<f32> {
382 assert_eq!(context_coords.len(), context_is_user.len());
383
384 if context_coords.is_empty() {
385 return Vec::new();
386 }
387
388 let sw = self.config.spatial_weight;
389
390 let raw_scores: Vec<f32> = context_coords
392 .iter()
393 .zip(context_embs.iter())
394 .zip(context_is_user.iter())
395 .map(|((coord, emb), &is_user)| {
396 if is_user == query_is_user {
398 return 0.0;
399 }
400
401 if self.config.causal_mask && coord.temporal > query_coord.temporal {
403 return 0.0;
404 }
405
406 let spatial = self.spatial_weight(query_coord, coord);
407 let semantic = self.semantic_weight(query_emb, emb);
408
409 sw * spatial + (1.0 - sw) * semantic
410 })
411 .collect();
412
413 self.softmax(&raw_scores)
414 }
415
416 pub fn compute_attention(
420 &self,
421 query_coord: &TrajectoryCoordinate5D,
422 context_coords: &[TrajectoryCoordinate5D],
423 query_emb: &[f32],
424 context_embs: &[&[f32]],
425 ) -> AttentionWeights {
426 let (forward, raw_scores) =
427 self.compute_forward_attention(query_coord, context_coords, query_emb, context_embs);
428
429 if forward.is_empty() {
430 return AttentionWeights::empty();
431 }
432
433 let inverse = self.compute_inverse_attention(&forward, &forward);
435
436 let cross = forward.clone();
438
439 let total_mass = forward.iter().sum();
440
441 AttentionWeights {
442 forward,
443 inverse,
444 cross,
445 raw_scores,
446 total_mass,
447 }
448 }
449
450 pub fn compute_attention_with_roles(
452 &self,
453 query_coord: &TrajectoryCoordinate5D,
454 context_coords: &[TrajectoryCoordinate5D],
455 query_emb: &[f32],
456 context_embs: &[&[f32]],
457 query_is_user: bool,
458 context_is_user: &[bool],
459 influences: &[f32],
460 ) -> AttentionWeights {
461 let (forward, raw_scores) =
462 self.compute_forward_attention(query_coord, context_coords, query_emb, context_embs);
463
464 if forward.is_empty() {
465 return AttentionWeights::empty();
466 }
467
468 let inverse = self.compute_inverse_attention(&forward, influences);
469
470 let cross = self.compute_cross_attention(
471 query_coord,
472 context_coords,
473 query_emb,
474 context_embs,
475 query_is_user,
476 context_is_user,
477 );
478
479 let total_mass = forward.iter().sum();
480
481 AttentionWeights {
482 forward,
483 inverse,
484 cross,
485 raw_scores,
486 total_mass,
487 }
488 }
489
490 pub fn propagate_sequence(
494 &self,
495 coords: &[TrajectoryCoordinate5D],
496 embeddings: &[&[f32]],
497 ) -> Vec<AttentionWeights> {
498 assert_eq!(coords.len(), embeddings.len());
499
500 let n = coords.len();
501 if n == 0 {
502 return Vec::new();
503 }
504
505 let mut results = Vec::with_capacity(n);
506
507 for i in 0..n {
508 let context_coords: Vec<_> = coords[..i].to_vec();
510 let context_embs: Vec<_> = embeddings[..i].iter().copied().collect();
511
512 if context_coords.is_empty() {
513 results.push(AttentionWeights::empty());
514 } else {
515 let weights = self.compute_attention(
516 &coords[i],
517 &context_coords,
518 embeddings[i],
519 &context_embs,
520 );
521 results.push(weights);
522 }
523 }
524
525 results
526 }
527
528 pub fn config(&self) -> &IRCPConfig {
530 &self.config
531 }
532
533 pub fn set_config(&mut self, config: IRCPConfig) {
535 self.config = config;
536 }
537}
538
539impl Default for IRCPPropagator {
540 fn default() -> Self {
541 Self::new(IRCPConfig::default())
542 }
543}
544
545pub fn batch_compute_attention(
553 propagator: &IRCPPropagator,
554 query_coords: &[TrajectoryCoordinate5D],
555 query_embs: &[&[f32]],
556 context_coords: &[TrajectoryCoordinate5D],
557 context_embs: &[&[f32]],
558) -> Vec<AttentionWeights> {
559 query_coords
560 .iter()
561 .zip(query_embs.iter())
562 .map(|(coord, emb)| propagator.compute_attention(coord, context_coords, emb, context_embs))
563 .collect()
564}
565
566pub fn compute_attention_matrix(
570 propagator: &IRCPPropagator,
571 coords: &[TrajectoryCoordinate5D],
572 embeddings: &[&[f32]],
573) -> Vec<Vec<f32>> {
574 let n = coords.len();
575 let mut matrix = vec![vec![0.0; n]; n];
576
577 for i in 0..n {
578 let weights = propagator.compute_attention(&coords[i], coords, embeddings[i], embeddings);
579 matrix[i] = weights.forward;
580 }
581
582 matrix
583}
584
#[cfg(test)]
mod tests {
    use super::*;

    /// Coordinate fixture that varies only in depth and temporal position.
    fn make_test_coord(depth: u32, temporal: f32) -> TrajectoryCoordinate5D {
        TrajectoryCoordinate5D::new(depth, 0, 0.8, temporal, 1)
    }

    /// Deterministic 8-dimensional embedding derived from `seed`.
    fn make_test_embedding(seed: f32) -> Vec<f32> {
        (0..8).map(|i| (seed + i as f32 * 0.1).sin()).collect()
    }

    /// Borrows a list of owned embeddings as the slice-of-slices the API expects.
    fn as_slices(embeddings: &[Vec<f32>]) -> Vec<&[f32]> {
        embeddings.iter().map(Vec::as_slice).collect()
    }

    /// Weights fixture with the given forward distribution and flat other fields.
    fn weights_with_forward(forward: Vec<f32>) -> AttentionWeights {
        let n = forward.len();
        let flat = 1.0 / n as f32;
        AttentionWeights {
            forward,
            inverse: vec![flat; n],
            cross: vec![flat; n],
            raw_scores: vec![1.0; n],
            total_mass: 1.0,
        }
    }

    #[test]
    fn test_ircp_config_default() {
        let config = IRCPConfig::default();
        assert!((config.temperature - 1.0).abs() < 1e-6);
        assert!((config.spatial_weight - 0.3).abs() < 1e-6);
        assert!(!config.causal_mask);
    }

    #[test]
    fn test_ircp_config_presets() {
        assert!(IRCPConfig::semantic_focused().spatial_weight < 0.2);
        assert!(IRCPConfig::spatial_focused().spatial_weight > 0.5);
        assert!(IRCPConfig::causal().causal_mask);
        assert!(IRCPConfig::sharp().temperature < 0.5);
        assert!(IRCPConfig::diffuse().temperature > 2.0);
    }

    #[test]
    fn test_attention_weights_uniform() {
        let weights = AttentionWeights::uniform(5);
        assert_eq!(weights.forward.len(), 5);
        assert!((weights.forward[0] - 0.2).abs() < 1e-6);
        assert!((weights.total_mass - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_attention_weights_empty() {
        let weights = AttentionWeights::empty();
        assert!(weights.forward.is_empty());
        assert!(weights.total_mass < 1e-6);
    }

    #[test]
    fn test_attention_weights_top_k() {
        let weights = weights_with_forward(vec![0.1, 0.5, 0.2, 0.15, 0.05]);

        // Index 1 carries the largest weight.
        assert_eq!(weights.top_forward(), Some(1));

        // Descending order of weight: 0.5, 0.2, 0.15.
        assert_eq!(weights.top_k_forward(3), vec![1, 2, 3]);
    }

    #[test]
    fn test_propagator_empty_context() {
        let propagator = IRCPPropagator::default();
        let query = make_test_coord(3, 0.5);
        let query_emb = make_test_embedding(1.0);

        let weights = propagator.compute_attention(&query, &[], &query_emb, &[]);

        assert!(weights.forward.is_empty());
        assert!(weights.total_mass < 1e-6);
    }

    #[test]
    fn test_propagator_single_context() {
        let propagator = IRCPPropagator::default();
        let query = make_test_coord(3, 0.5);
        let query_emb = make_test_embedding(1.0);

        let context = vec![make_test_coord(1, 0.2)];
        let owned_embs = vec![make_test_embedding(1.1)];
        let context_embs = as_slices(&owned_embs);

        let weights = propagator.compute_attention(&query, &context, &query_emb, &context_embs);

        // A single context entry receives all of the attention.
        assert_eq!(weights.forward.len(), 1);
        assert!((weights.forward[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_propagator_multiple_context() {
        let propagator = IRCPPropagator::default();
        let query = make_test_coord(3, 0.5);
        let query_emb = make_test_embedding(1.0);

        let context = vec![
            make_test_coord(1, 0.1),
            make_test_coord(2, 0.3),
            make_test_coord(4, 0.6),
        ];
        let owned_embs = vec![
            make_test_embedding(0.5),
            make_test_embedding(0.9),
            make_test_embedding(2.0),
        ];
        let context_embs = as_slices(&owned_embs);

        let weights = propagator.compute_attention(&query, &context, &query_emb, &context_embs);

        assert_eq!(weights.forward.len(), 3);

        let total: f32 = weights.forward.iter().sum();
        assert!((total - 1.0).abs() < 1e-5);

        // The second entry's embedding seed (0.9) is closer to the query's (1.0).
        assert!(weights.forward[1] > weights.forward[0]);
    }

    #[test]
    fn test_propagator_causal_mask() {
        let propagator = IRCPPropagator::new(IRCPConfig {
            causal_mask: true,
            ..IRCPConfig::default()
        });

        let query = make_test_coord(2, 0.5);
        let query_emb = make_test_embedding(1.0);

        // One entry before the query in time (0.2) and one after it (0.8).
        let context = vec![make_test_coord(1, 0.2), make_test_coord(3, 0.8)];
        let owned_embs = vec![make_test_embedding(1.0), make_test_embedding(1.0)];
        let context_embs = as_slices(&owned_embs);

        let weights = propagator.compute_attention(&query, &context, &query_emb, &context_embs);

        assert!(weights.forward[0] > weights.forward[1]);
    }

    #[test]
    fn test_propagator_inverse_attention() {
        let propagator = IRCPPropagator::default();
        let forward = vec![0.2, 0.5, 0.3];
        let influences = vec![1.0, 0.5, 1.5];

        let inverse = propagator.compute_inverse_attention(&forward, &influences);

        assert_eq!(inverse.len(), 3);
        let total: f32 = inverse.iter().sum();
        assert!((total - 1.0).abs() < 1e-5);

        // 0.3 * 1.5 = 0.45 outweighs 0.5 * 0.5 = 0.25.
        assert!(inverse[2] > inverse[1]);
    }

    #[test]
    fn test_propagator_cross_attention() {
        let propagator = IRCPPropagator::default();

        let query = make_test_coord(2, 0.5);
        let query_emb = make_test_embedding(1.0);
        let query_is_user = true;

        let context = vec![
            make_test_coord(1, 0.2),
            make_test_coord(2, 0.4),
            make_test_coord(3, 0.6),
        ];
        // Only the middle entry shares the query's role.
        let context_is_user = vec![false, true, false];
        let owned_embs = vec![
            make_test_embedding(1.0),
            make_test_embedding(1.0),
            make_test_embedding(1.0),
        ];
        let context_embs = as_slices(&owned_embs);

        let cross = propagator.compute_cross_attention(
            &query,
            &context,
            &query_emb,
            &context_embs,
            query_is_user,
            &context_is_user,
        );

        assert_eq!(cross.len(), 3);
        assert!(cross[0] > cross[1]);
        assert!(cross[2] > cross[1]);
    }

    #[test]
    fn test_propagate_sequence() {
        let propagator = IRCPPropagator::default();

        let coords = vec![
            make_test_coord(0, 0.0),
            make_test_coord(1, 0.25),
            make_test_coord(2, 0.5),
            make_test_coord(3, 0.75),
        ];
        let owned_embs = vec![
            make_test_embedding(0.0),
            make_test_embedding(0.5),
            make_test_embedding(1.0),
            make_test_embedding(1.5),
        ];
        let embeddings = as_slices(&owned_embs);

        let results = propagator.propagate_sequence(&coords, &embeddings);

        assert_eq!(results.len(), 4);

        // Position i attends to the i positions before it.
        assert!(results[0].forward.is_empty());
        assert_eq!(results[1].forward.len(), 1);
        assert_eq!(results[2].forward.len(), 2);
        assert_eq!(results[3].forward.len(), 3);
    }

    #[test]
    fn test_batch_compute_attention() {
        let propagator = IRCPPropagator::default();

        let query_coords = vec![make_test_coord(3, 0.5), make_test_coord(4, 0.7)];
        let context_coords = vec![make_test_coord(1, 0.1), make_test_coord(2, 0.3)];

        let owned_query_embs = vec![make_test_embedding(1.0), make_test_embedding(1.5)];
        let query_embs = as_slices(&owned_query_embs);

        let owned_context_embs = vec![make_test_embedding(0.5), make_test_embedding(1.0)];
        let context_embs = as_slices(&owned_context_embs);

        let results = batch_compute_attention(
            &propagator,
            &query_coords,
            &query_embs,
            &context_coords,
            &context_embs,
        );

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].forward.len(), 2);
        assert_eq!(results[1].forward.len(), 2);
    }

    #[test]
    fn test_compute_attention_matrix() {
        let propagator = IRCPPropagator::default();

        let coords = vec![
            make_test_coord(0, 0.0),
            make_test_coord(1, 0.5),
            make_test_coord(2, 1.0),
        ];
        let owned_embs = vec![
            make_test_embedding(0.0),
            make_test_embedding(0.5),
            make_test_embedding(1.0),
        ];
        let embeddings = as_slices(&owned_embs);

        let matrix = compute_attention_matrix(&propagator, &coords, &embeddings);

        assert_eq!(matrix.len(), 3);
        assert_eq!(matrix[0].len(), 3);

        // Every row is a normalized distribution.
        for row in &matrix {
            let total: f32 = row.iter().sum();
            assert!((total - 1.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_attention_entropy() {
        let uniform = AttentionWeights::uniform(4);
        let uniform_entropy = uniform.forward_entropy();

        let concentrated = weights_with_forward(vec![0.97, 0.01, 0.01, 0.01]);
        let concentrated_entropy = concentrated.forward_entropy();

        assert!(uniform_entropy > concentrated_entropy);
        assert!(concentrated.is_concentrated(0.5));
        assert!(!uniform.is_concentrated(0.5));
    }
}