// rag_plusplus_core/trajectory/ircp.rs

//! I-RCP (Inverse Ring Contextual Propagation) Module
//!
//! Implements the I-RCP propagation algorithm for computing attention weights
//! in a dual-ring trajectory structure. Based on the DLM package analysis.
//!
//! # Key Concepts
//!
//! - **Forward Attention (A_F)**: Attention from earlier nodes to later nodes (RCP)
//! - **Inverse Attention (A_I)**: Inferred attention from context that produced a response (IRCP)
//! - **Cross Attention (A_C)**: Attention flow between user and assistant turns
//!
//! # Algorithm
//!
//! Raw scores are a weighted blend of the spatial and semantic components,
//! controlled by the `spatial_weight` knob `w_s`:
//!
//! ```text
//! raw_score[i] = w_s * spatial_weight(query_coord, context_coord[i])
//!              + (1 - w_s) * semantic_weight(query_emb, context_emb[i])
//! attention = softmax(raw_scores / temperature)
//! ```
//!
//! Where:
//! - `spatial_weight = exp(-distance(query_coord, context_coord))`
//! - `semantic_weight = (1 + cosine_similarity) / 2` (normalized to [0, 1])
//!
//! # Usage
//!
//! ```ignore
//! use rag_plusplus_core::trajectory::ircp::{IRCPPropagator, IRCPConfig, AttentionWeights};
//! use rag_plusplus_core::trajectory::{TrajectoryCoordinate5D, DLMWeights};
//!
//! let config = IRCPConfig::default();
//! let propagator = IRCPPropagator::new(config);
//!
//! let query_coord = TrajectoryCoordinate5D::new(3, 0, 0.9, 0.5, 1);
//! let context_coords = vec![
//!     TrajectoryCoordinate5D::new(1, 0, 0.8, 0.2, 1),
//!     TrajectoryCoordinate5D::new(2, 0, 0.85, 0.4, 2),
//! ];
//! let query_emb = vec![0.5; 768];
//! let context_embs = vec![vec![0.5; 768], vec![0.6; 768]];
//!
//! let weights = propagator.compute_attention(
//!     &query_coord,
//!     &context_coords,
//!     &query_emb,
//!     &context_embs.iter().map(|e| e.as_slice()).collect::<Vec<_>>(),
//! );
//! ```
49
50use crate::distance::cosine_similarity_fast;
51use crate::trajectory::{TrajectoryCoordinate5D, DLMWeights};
52
/// Configuration for I-RCP propagation.
///
/// Holds every knob that shapes how raw scores are computed and normalized
/// by [`IRCPPropagator`]: the softmax temperature, the spatial/semantic blend,
/// the distance metric choice, and masking behavior.
#[derive(Debug, Clone)]
pub struct IRCPConfig {
    /// Temperature for softmax (lower = sharper attention)
    pub temperature: f32,

    /// Weight configuration for coordinate distance
    pub coord_weights: DLMWeights,

    /// Relative weight of spatial vs semantic components [0, 1]
    /// 0 = pure semantic, 1 = pure spatial
    pub spatial_weight: f32,

    /// Whether to use cosine distance (true) or coordinate distance (false) for spatial
    pub use_coordinate_cosine: bool,

    /// Minimum attention weight (prevents division by zero)
    pub min_attention: f32,

    /// Whether to apply causal masking (future nodes get zero attention)
    pub causal_mask: bool,
}
75
76impl Default for IRCPConfig {
77    fn default() -> Self {
78        Self {
79            temperature: 1.0,
80            coord_weights: DLMWeights::default(),
81            spatial_weight: 0.3, // Favor semantic similarity
82            use_coordinate_cosine: false,
83            min_attention: 1e-10,
84            causal_mask: false,
85        }
86    }
87}
88
89impl IRCPConfig {
90    /// Configuration that heavily weights semantic similarity.
91    pub fn semantic_focused() -> Self {
92        Self {
93            spatial_weight: 0.1,
94            coord_weights: DLMWeights::semantic_focused(),
95            ..Default::default()
96        }
97    }
98
99    /// Configuration that heavily weights coordinate distance.
100    pub fn spatial_focused() -> Self {
101        Self {
102            spatial_weight: 0.7,
103            coord_weights: DLMWeights::structural_focused(),
104            ..Default::default()
105        }
106    }
107
108    /// Configuration for causal attention (no looking at future).
109    pub fn causal() -> Self {
110        Self {
111            causal_mask: true,
112            ..Default::default()
113        }
114    }
115
116    /// Sharp attention (low temperature).
117    pub fn sharp() -> Self {
118        Self {
119            temperature: 0.1,
120            ..Default::default()
121        }
122    }
123
124    /// Diffuse attention (high temperature).
125    pub fn diffuse() -> Self {
126        Self {
127            temperature: 3.0,
128            ..Default::default()
129        }
130    }
131}
132
/// Computed attention weights from I-RCP propagation.
#[derive(Debug, Clone)]
pub struct AttentionWeights {
    /// Forward attention weights (A_F): query → context
    pub forward: Vec<f32>,

    /// Inverse attention weights (A_I): context → query (inferred)
    pub inverse: Vec<f32>,

    /// Cross attention weights (A_C): between user/assistant turns
    pub cross: Vec<f32>,

    /// Raw scores before softmax (for debugging)
    pub raw_scores: Vec<f32>,

    /// Total attention mass (should be ~1.0 after softmax)
    pub total_mass: f32,
}

impl AttentionWeights {
    /// An attention result with no entries and zero mass.
    pub fn empty() -> Self {
        Self {
            forward: vec![],
            inverse: vec![],
            cross: vec![],
            raw_scores: vec![],
            total_mass: 0.0,
        }
    }

    /// Uniform attention over `n` entries (each weight `1/n`).
    ///
    /// Returns [`AttentionWeights::empty`] when `n == 0`.
    pub fn uniform(n: usize) -> Self {
        match n {
            0 => Self::empty(),
            _ => {
                let w = (n as f32).recip();
                Self {
                    forward: vec![w; n],
                    inverse: vec![w; n],
                    cross: vec![w; n],
                    raw_scores: vec![1.0; n],
                    total_mass: 1.0,
                }
            }
        }
    }

    /// Index of the largest forward weight, or `None` when empty.
    pub fn top_forward(&self) -> Option<usize> {
        // NaN-tolerant comparator: incomparable values are treated as equal.
        let cmp = |x: &f32, y: &f32| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal);
        self.forward
            .iter()
            .enumerate()
            .max_by(|lhs, rhs| cmp(lhs.1, rhs.1))
            .map(|(idx, _)| idx)
    }

    /// Indices ordered by forward attention, highest first.
    pub fn sorted_forward_indices(&self) -> Vec<usize> {
        let mut order: Vec<usize> = (0..self.forward.len()).collect();
        // Stable descending sort: comparing rhs against lhs flips the order.
        order.sort_by(|&lhs, &rhs| {
            self.forward[rhs]
                .partial_cmp(&self.forward[lhs])
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        order
    }

    /// The (up to) `k` indices with the highest forward attention.
    pub fn top_k_forward(&self, k: usize) -> Vec<usize> {
        let mut ranked = self.sorted_forward_indices();
        ranked.truncate(k);
        ranked
    }

    /// Shannon entropy (natural log) of the forward attention distribution.
    pub fn forward_entropy(&self) -> f32 {
        let mut acc = 0.0f32;
        for &w in &self.forward {
            // Skip (near-)zero weights: lim w→0 of w·ln(w) is 0.
            if w > 1e-10 {
                acc += w * w.ln();
            }
        }
        -acc
    }

    /// True when the forward distribution's entropy falls below `threshold`.
    pub fn is_concentrated(&self, threshold: f32) -> bool {
        self.forward_entropy() < threshold
    }
}
220
/// I-RCP propagation engine.
///
/// Computes attention weights for queries over context sets using
/// both spatial (coordinate) and semantic (embedding) similarity.
#[derive(Debug, Clone)]
pub struct IRCPPropagator {
    // Tunable parameters: temperature, spatial/semantic blend, masking, etc.
    config: IRCPConfig,
}
229
230impl IRCPPropagator {
231    /// Create a new propagator with configuration.
232    pub fn new(config: IRCPConfig) -> Self {
233        Self { config }
234    }
235
236    /// Compute spatial weight between two coordinates.
237    ///
238    /// Higher weight = closer in coordinate space.
239    #[inline]
240    fn spatial_weight(&self, query: &TrajectoryCoordinate5D, context: &TrajectoryCoordinate5D) -> f32 {
241        if self.config.use_coordinate_cosine {
242            // Use cosine similarity of coordinates (direction-based)
243            (1.0 + query.cosine_similarity(context)) / 2.0
244        } else {
245            // Use exponential decay of coordinate distance
246            let dist = query.dlm_distance(context, &self.config.coord_weights);
247            (-dist).exp()
248        }
249    }
250
251    /// Compute semantic weight between two embeddings.
252    ///
253    /// Higher weight = more semantically similar.
254    #[inline]
255    fn semantic_weight(&self, query_emb: &[f32], context_emb: &[f32]) -> f32 {
256        // Normalize cosine similarity from [-1, 1] to [0, 1]
257        (1.0 + cosine_similarity_fast(query_emb, context_emb)) / 2.0
258    }
259
260    /// Apply softmax to raw scores.
261    fn softmax(&self, scores: &[f32]) -> Vec<f32> {
262        if scores.is_empty() {
263            return Vec::new();
264        }
265
266        // Find max for numerical stability
267        let max_score = scores
268            .iter()
269            .copied()
270            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
271            .unwrap_or(0.0);
272
273        // Compute exp((x - max) / temperature)
274        let exps: Vec<f32> = scores
275            .iter()
276            .map(|&s| ((s - max_score) / self.config.temperature).exp())
277            .collect();
278
279        // Normalize
280        let sum: f32 = exps.iter().sum();
281        if sum > 0.0 {
282            exps.iter().map(|e| (e / sum).max(self.config.min_attention)).collect()
283        } else {
284            vec![1.0 / scores.len() as f32; scores.len()]
285        }
286    }
287
288    /// Compute forward attention weights (A_F).
289    ///
290    /// Forward attention: how much should the query attend to each context item?
291    pub fn compute_forward_attention(
292        &self,
293        query_coord: &TrajectoryCoordinate5D,
294        context_coords: &[TrajectoryCoordinate5D],
295        query_emb: &[f32],
296        context_embs: &[&[f32]],
297    ) -> (Vec<f32>, Vec<f32>) {
298        assert_eq!(
299            context_coords.len(),
300            context_embs.len(),
301            "Coordinate and embedding counts must match"
302        );
303
304        if context_coords.is_empty() {
305            return (Vec::new(), Vec::new());
306        }
307
308        let sw = self.config.spatial_weight;
309        let raw_scores: Vec<f32> = context_coords
310            .iter()
311            .zip(context_embs.iter())
312            .enumerate()
313            .map(|(_i, (coord, emb))| {
314                // Apply causal mask if enabled
315                if self.config.causal_mask && coord.temporal > query_coord.temporal {
316                    return 0.0;
317                }
318
319                let spatial = self.spatial_weight(query_coord, coord);
320                let semantic = self.semantic_weight(query_emb, emb);
321
322                // Weighted combination
323                sw * spatial + (1.0 - sw) * semantic
324            })
325            .collect();
326
327        let attention = self.softmax(&raw_scores);
328        (attention, raw_scores)
329    }
330
331    /// Compute inverse attention weights (A_I).
332    ///
333    /// Inverse attention: given a response, infer what context produced it.
334    /// This is the "inverse" of forward attention.
335    ///
336    /// The inverse is computed by normalizing attention received:
337    /// A_I[i] = A_F[i] * influence[i] / sum(A_F * influence)
338    pub fn compute_inverse_attention(
339        &self,
340        forward_attention: &[f32],
341        influences: &[f32],
342    ) -> Vec<f32> {
343        assert_eq!(
344            forward_attention.len(),
345            influences.len(),
346            "Attention and influence counts must match"
347        );
348
349        if forward_attention.is_empty() {
350            return Vec::new();
351        }
352
353        // Weight attention by influence
354        let weighted: Vec<f32> = forward_attention
355            .iter()
356            .zip(influences.iter())
357            .map(|(&a, &inf)| a * inf)
358            .collect();
359
360        // Normalize
361        let sum: f32 = weighted.iter().sum();
362        if sum > 0.0 {
363            weighted.iter().map(|w| w / sum).collect()
364        } else {
365            vec![1.0 / weighted.len() as f32; weighted.len()]
366        }
367    }
368
369    /// Compute cross attention weights (A_C).
370    ///
371    /// Cross attention captures flow between user and assistant turns.
372    /// User turns receive attention from assistant context, and vice versa.
373    pub fn compute_cross_attention(
374        &self,
375        query_coord: &TrajectoryCoordinate5D,
376        context_coords: &[TrajectoryCoordinate5D],
377        query_emb: &[f32],
378        context_embs: &[&[f32]],
379        query_is_user: bool,
380        context_is_user: &[bool],
381    ) -> Vec<f32> {
382        assert_eq!(context_coords.len(), context_is_user.len());
383
384        if context_coords.is_empty() {
385            return Vec::new();
386        }
387
388        let sw = self.config.spatial_weight;
389
390        // Cross attention only applies to opposite roles
391        let raw_scores: Vec<f32> = context_coords
392            .iter()
393            .zip(context_embs.iter())
394            .zip(context_is_user.iter())
395            .map(|((coord, emb), &is_user)| {
396                // Only attend to opposite role
397                if is_user == query_is_user {
398                    return 0.0;
399                }
400
401                // Apply causal mask if enabled
402                if self.config.causal_mask && coord.temporal > query_coord.temporal {
403                    return 0.0;
404                }
405
406                let spatial = self.spatial_weight(query_coord, coord);
407                let semantic = self.semantic_weight(query_emb, emb);
408
409                sw * spatial + (1.0 - sw) * semantic
410            })
411            .collect();
412
413        self.softmax(&raw_scores)
414    }
415
416    /// Compute full I-RCP attention weights.
417    ///
418    /// Returns forward, inverse, and cross attention in one call.
419    pub fn compute_attention(
420        &self,
421        query_coord: &TrajectoryCoordinate5D,
422        context_coords: &[TrajectoryCoordinate5D],
423        query_emb: &[f32],
424        context_embs: &[&[f32]],
425    ) -> AttentionWeights {
426        let (forward, raw_scores) =
427            self.compute_forward_attention(query_coord, context_coords, query_emb, context_embs);
428
429        if forward.is_empty() {
430            return AttentionWeights::empty();
431        }
432
433        // For inverse, use forward attention as influence proxy
434        let inverse = self.compute_inverse_attention(&forward, &forward);
435
436        // For cross, we'd need role information - use forward as placeholder
437        let cross = forward.clone();
438
439        let total_mass = forward.iter().sum();
440
441        AttentionWeights {
442            forward,
443            inverse,
444            cross,
445            raw_scores,
446            total_mass,
447        }
448    }
449
450    /// Compute attention with role information for cross attention.
451    pub fn compute_attention_with_roles(
452        &self,
453        query_coord: &TrajectoryCoordinate5D,
454        context_coords: &[TrajectoryCoordinate5D],
455        query_emb: &[f32],
456        context_embs: &[&[f32]],
457        query_is_user: bool,
458        context_is_user: &[bool],
459        influences: &[f32],
460    ) -> AttentionWeights {
461        let (forward, raw_scores) =
462            self.compute_forward_attention(query_coord, context_coords, query_emb, context_embs);
463
464        if forward.is_empty() {
465            return AttentionWeights::empty();
466        }
467
468        let inverse = self.compute_inverse_attention(&forward, influences);
469
470        let cross = self.compute_cross_attention(
471            query_coord,
472            context_coords,
473            query_emb,
474            context_embs,
475            query_is_user,
476            context_is_user,
477        );
478
479        let total_mass = forward.iter().sum();
480
481        AttentionWeights {
482            forward,
483            inverse,
484            cross,
485            raw_scores,
486            total_mass,
487        }
488    }
489
490    /// Propagate attention through a sequence of queries.
491    ///
492    /// Returns attention weights for each query position.
493    pub fn propagate_sequence(
494        &self,
495        coords: &[TrajectoryCoordinate5D],
496        embeddings: &[&[f32]],
497    ) -> Vec<AttentionWeights> {
498        assert_eq!(coords.len(), embeddings.len());
499
500        let n = coords.len();
501        if n == 0 {
502            return Vec::new();
503        }
504
505        let mut results = Vec::with_capacity(n);
506
507        for i in 0..n {
508            // Context is all nodes before current position
509            let context_coords: Vec<_> = coords[..i].to_vec();
510            let context_embs: Vec<_> = embeddings[..i].iter().copied().collect();
511
512            if context_coords.is_empty() {
513                results.push(AttentionWeights::empty());
514            } else {
515                let weights = self.compute_attention(
516                    &coords[i],
517                    &context_coords,
518                    embeddings[i],
519                    &context_embs,
520                );
521                results.push(weights);
522            }
523        }
524
525        results
526    }
527
528    /// Get the config.
529    pub fn config(&self) -> &IRCPConfig {
530        &self.config
531    }
532
533    /// Update config.
534    pub fn set_config(&mut self, config: IRCPConfig) {
535        self.config = config;
536    }
537}
538
539impl Default for IRCPPropagator {
540    fn default() -> Self {
541        Self::new(IRCPConfig::default())
542    }
543}
544
// ============================================================================
// Batch Operations
// ============================================================================
548
549/// Batch compute attention for multiple queries against shared context.
550///
551/// More efficient than calling compute_attention repeatedly when context is shared.
552pub fn batch_compute_attention(
553    propagator: &IRCPPropagator,
554    query_coords: &[TrajectoryCoordinate5D],
555    query_embs: &[&[f32]],
556    context_coords: &[TrajectoryCoordinate5D],
557    context_embs: &[&[f32]],
558) -> Vec<AttentionWeights> {
559    query_coords
560        .iter()
561        .zip(query_embs.iter())
562        .map(|(coord, emb)| propagator.compute_attention(coord, context_coords, emb, context_embs))
563        .collect()
564}
565
566/// Compute attention matrix between all pairs.
567///
568/// Returns n x n matrix where element [i][j] is attention from i to j.
569pub fn compute_attention_matrix(
570    propagator: &IRCPPropagator,
571    coords: &[TrajectoryCoordinate5D],
572    embeddings: &[&[f32]],
573) -> Vec<Vec<f32>> {
574    let n = coords.len();
575    let mut matrix = vec![vec![0.0; n]; n];
576
577    for i in 0..n {
578        let weights = propagator.compute_attention(&coords[i], coords, embeddings[i], embeddings);
579        matrix[i] = weights.forward;
580    }
581
582    matrix
583}
584
#[cfg(test)]
mod tests {
    use super::*;

    // Helper: builds a 5-D coordinate; only depth and temporal vary per test.
    fn make_test_coord(depth: u32, temporal: f32) -> TrajectoryCoordinate5D {
        TrajectoryCoordinate5D::new(depth, 0, 0.8, temporal, 1)
    }

    // Helper: deterministic 8-dim embedding derived from a seed value.
    fn make_test_embedding(seed: f32) -> Vec<f32> {
        (0..8).map(|i| (seed + i as f32 * 0.1).sin()).collect()
    }

    #[test]
    fn test_ircp_config_default() {
        let config = IRCPConfig::default();
        assert!((config.temperature - 1.0).abs() < 1e-6);
        assert!((config.spatial_weight - 0.3).abs() < 1e-6);
        assert!(!config.causal_mask);
    }

    #[test]
    fn test_ircp_config_presets() {
        let semantic = IRCPConfig::semantic_focused();
        assert!(semantic.spatial_weight < 0.2);

        let spatial = IRCPConfig::spatial_focused();
        assert!(spatial.spatial_weight > 0.5);

        let causal = IRCPConfig::causal();
        assert!(causal.causal_mask);

        let sharp = IRCPConfig::sharp();
        assert!(sharp.temperature < 0.5);

        let diffuse = IRCPConfig::diffuse();
        assert!(diffuse.temperature > 2.0);
    }

    #[test]
    fn test_attention_weights_uniform() {
        let weights = AttentionWeights::uniform(5);
        assert_eq!(weights.forward.len(), 5);
        assert!((weights.forward[0] - 0.2).abs() < 1e-6);
        assert!((weights.total_mass - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_attention_weights_empty() {
        let weights = AttentionWeights::empty();
        assert!(weights.forward.is_empty());
        assert!(weights.total_mass < 1e-6);
    }

    #[test]
    fn test_attention_weights_top_k() {
        let weights = AttentionWeights {
            forward: vec![0.1, 0.5, 0.2, 0.15, 0.05],
            inverse: vec![0.2; 5],
            cross: vec![0.2; 5],
            raw_scores: vec![1.0; 5],
            total_mass: 1.0,
        };

        let top1 = weights.top_forward();
        assert_eq!(top1, Some(1)); // Index 1 has 0.5

        let top3 = weights.top_k_forward(3);
        assert_eq!(top3, vec![1, 2, 3]); // 0.5, 0.2, 0.15
    }

    #[test]
    fn test_propagator_empty_context() {
        let propagator = IRCPPropagator::default();
        let query = make_test_coord(3, 0.5);
        let query_emb = make_test_embedding(1.0);

        // No context at all -> empty weights, zero mass
        let weights = propagator.compute_attention(&query, &[], &query_emb, &[]);

        assert!(weights.forward.is_empty());
        assert!(weights.total_mass < 1e-6);
    }

    #[test]
    fn test_propagator_single_context() {
        let propagator = IRCPPropagator::default();
        let query = make_test_coord(3, 0.5);
        let context = vec![make_test_coord(1, 0.2)];
        let query_emb = make_test_embedding(1.0);
        let context_emb = make_test_embedding(1.1);
        let context_embs: Vec<&[f32]> = vec![&context_emb];

        let weights = propagator.compute_attention(&query, &context, &query_emb, &context_embs);

        assert_eq!(weights.forward.len(), 1);
        assert!((weights.forward[0] - 1.0).abs() < 1e-6); // Single element gets all attention
    }

    #[test]
    fn test_propagator_multiple_context() {
        let propagator = IRCPPropagator::default();
        let query = make_test_coord(3, 0.5);
        let context = vec![
            make_test_coord(1, 0.1),
            make_test_coord(2, 0.3),
            make_test_coord(4, 0.6),
        ];
        let query_emb = make_test_embedding(1.0);
        let context_emb1 = make_test_embedding(0.5);
        let context_emb2 = make_test_embedding(0.9); // More similar to query
        let context_emb3 = make_test_embedding(2.0);
        let context_embs: Vec<&[f32]> = vec![&context_emb1, &context_emb2, &context_emb3];

        let weights = propagator.compute_attention(&query, &context, &query_emb, &context_embs);

        assert_eq!(weights.forward.len(), 3);

        // All attention weights should sum to 1.0
        let sum: f32 = weights.forward.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);

        // Second context should have higher attention (more similar embedding)
        assert!(weights.forward[1] > weights.forward[0]);
    }

    #[test]
    fn test_propagator_causal_mask() {
        let mut config = IRCPConfig::default();
        config.causal_mask = true;
        let propagator = IRCPPropagator::new(config);

        let query = make_test_coord(2, 0.5); // temporal = 0.5
        let context = vec![
            make_test_coord(1, 0.2), // Before query (should attend)
            make_test_coord(3, 0.8), // After query (should be masked)
        ];
        let query_emb = make_test_embedding(1.0);
        let context_emb1 = make_test_embedding(1.0);
        let context_emb2 = make_test_embedding(1.0);
        let context_embs: Vec<&[f32]> = vec![&context_emb1, &context_emb2];

        let weights = propagator.compute_attention(&query, &context, &query_emb, &context_embs);

        // Future context should have near-zero attention
        assert!(weights.forward[0] > weights.forward[1]);
    }

    #[test]
    fn test_propagator_inverse_attention() {
        let propagator = IRCPPropagator::default();
        let forward = vec![0.2, 0.5, 0.3];
        let influences = vec![1.0, 0.5, 1.5]; // Different influences

        let inverse = propagator.compute_inverse_attention(&forward, &influences);

        assert_eq!(inverse.len(), 3);
        let sum: f32 = inverse.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);

        // High influence should boost attention
        // forward[2] * influence[2] = 0.3 * 1.5 = 0.45
        // forward[1] * influence[1] = 0.5 * 0.5 = 0.25
        // So inverse[2] should be higher than inverse[1]
        assert!(inverse[2] > inverse[1]);
    }

    #[test]
    fn test_propagator_cross_attention() {
        let propagator = IRCPPropagator::default();

        let query = make_test_coord(2, 0.5);
        let query_is_user = true;

        let context = vec![
            make_test_coord(1, 0.2),
            make_test_coord(2, 0.4),
            make_test_coord(3, 0.6),
        ];
        let context_is_user = vec![false, true, false]; // User only attends to assistant

        // Identical embeddings so only the role mask differentiates scores
        let query_emb = make_test_embedding(1.0);
        let context_emb1 = make_test_embedding(1.0);
        let context_emb2 = make_test_embedding(1.0);
        let context_emb3 = make_test_embedding(1.0);
        let context_embs: Vec<&[f32]> = vec![&context_emb1, &context_emb2, &context_emb3];

        let cross = propagator.compute_cross_attention(
            &query,
            &context,
            &query_emb,
            &context_embs,
            query_is_user,
            &context_is_user,
        );

        assert_eq!(cross.len(), 3);

        // Same-role context should have zero/minimal attention
        // context[1] is also user, so should be masked
        assert!(cross[0] > cross[1]); // Assistant > User
        assert!(cross[2] > cross[1]); // Assistant > User
    }

    #[test]
    fn test_propagate_sequence() {
        let propagator = IRCPPropagator::default();

        let coords = vec![
            make_test_coord(0, 0.0),
            make_test_coord(1, 0.25),
            make_test_coord(2, 0.5),
            make_test_coord(3, 0.75),
        ];

        let emb0 = make_test_embedding(0.0);
        let emb1 = make_test_embedding(0.5);
        let emb2 = make_test_embedding(1.0);
        let emb3 = make_test_embedding(1.5);
        let embeddings: Vec<&[f32]> = vec![&emb0, &emb1, &emb2, &emb3];

        let results = propagator.propagate_sequence(&coords, &embeddings);

        assert_eq!(results.len(), 4);

        // First position has no context
        assert!(results[0].forward.is_empty());

        // Second position has 1 context
        assert_eq!(results[1].forward.len(), 1);

        // Third position has 2 contexts
        assert_eq!(results[2].forward.len(), 2);

        // Fourth position has 3 contexts
        assert_eq!(results[3].forward.len(), 3);
    }

    #[test]
    fn test_batch_compute_attention() {
        let propagator = IRCPPropagator::default();

        let query_coords = vec![make_test_coord(3, 0.5), make_test_coord(4, 0.7)];

        let context_coords = vec![make_test_coord(1, 0.1), make_test_coord(2, 0.3)];

        let query_emb1 = make_test_embedding(1.0);
        let query_emb2 = make_test_embedding(1.5);
        let query_embs: Vec<&[f32]> = vec![&query_emb1, &query_emb2];

        let context_emb1 = make_test_embedding(0.5);
        let context_emb2 = make_test_embedding(1.0);
        let context_embs: Vec<&[f32]> = vec![&context_emb1, &context_emb2];

        let results = batch_compute_attention(
            &propagator,
            &query_coords,
            &query_embs,
            &context_coords,
            &context_embs,
        );

        // One result per query, each over the shared 2-item context
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].forward.len(), 2);
        assert_eq!(results[1].forward.len(), 2);
    }

    #[test]
    fn test_compute_attention_matrix() {
        let propagator = IRCPPropagator::default();

        let coords = vec![
            make_test_coord(0, 0.0),
            make_test_coord(1, 0.5),
            make_test_coord(2, 1.0),
        ];

        let emb0 = make_test_embedding(0.0);
        let emb1 = make_test_embedding(0.5);
        let emb2 = make_test_embedding(1.0);
        let embeddings: Vec<&[f32]> = vec![&emb0, &emb1, &emb2];

        let matrix = compute_attention_matrix(&propagator, &coords, &embeddings);

        assert_eq!(matrix.len(), 3);
        assert_eq!(matrix[0].len(), 3);

        // Each row should sum to 1.0
        for row in &matrix {
            let sum: f32 = row.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_attention_entropy() {
        // Uniform attention has high entropy
        let uniform = AttentionWeights::uniform(4);
        let uniform_entropy = uniform.forward_entropy();

        // Concentrated attention has low entropy
        let concentrated = AttentionWeights {
            forward: vec![0.97, 0.01, 0.01, 0.01],
            inverse: vec![0.25; 4],
            cross: vec![0.25; 4],
            raw_scores: vec![1.0; 4],
            total_mass: 1.0,
        };
        let concentrated_entropy = concentrated.forward_entropy();

        assert!(uniform_entropy > concentrated_entropy);
        assert!(concentrated.is_concentrated(0.5));
        assert!(!uniform.is_concentrated(0.5));
    }
}