// Source file: trueno/brick/tracing/attention.rs

// ============================================================================
// E.11.3: AttentionWeightTrace (MLT-02)
// ============================================================================

/// Sparse attention-weight record for a single (layer, head, query) triple.
///
/// Rather than storing the full attention row, only the `k` most-attended key
/// positions are kept, together with summary statistics of the whole row.
/// Useful for debugging repetition, context loss, and attention sinks.
///
/// `top_k_positions` and `top_k_weights` are parallel vectors: entry `i` of one
/// describes the same key position as entry `i` of the other.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct AttentionWeightTrace {
    /// Layer index
    pub layer_idx: usize,
    /// Head index within the layer
    pub head_idx: usize,
    /// Query position (current token being generated)
    pub query_pos: usize,
    /// Top-k attended key positions, sorted by weight descending
    pub top_k_positions: Vec<usize>,
    /// Weight of each entry in `top_k_positions` (same order)
    pub top_k_weights: Vec<f32>,
    /// Sum of weights outside the top-k (attention mass lost to the tail)
    pub tail_mass: f32,
    /// Shannon entropy of the full attention row in nats (higher = more uniform)
    pub entropy: f32,
}

27impl AttentionWeightTrace {
28    /// Create from full attention weights, extracting top-k.
29    pub fn from_weights(
30        layer_idx: usize,
31        head_idx: usize,
32        query_pos: usize,
33        weights: &[f32],
34        k: usize,
35    ) -> Self {
36        let k = k.min(weights.len());
37
38        // Create position-weight pairs and sort by weight descending
39        let mut pairs: Vec<(usize, f32)> = weights.iter().copied().enumerate().collect();
40        pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
41
42        let top_k_positions: Vec<usize> = pairs.iter().take(k).map(|(pos, _)| *pos).collect();
43        let top_k_weights: Vec<f32> = pairs.iter().take(k).map(|(_, w)| *w).collect();
44
45        let top_k_mass: f32 = top_k_weights.iter().sum();
46        let total_mass: f32 = weights.iter().sum();
47        let tail_mass = (total_mass - top_k_mass).max(0.0);
48
49        // Compute entropy: H = -sum(p * log(p)) for non-zero probabilities
50        let entropy =
51            weights.iter().filter(|&&w| w > 1e-10).map(|&w| -w * w.max(f32::EPSILON).ln()).sum();
52
53        Self { layer_idx, head_idx, query_pos, top_k_positions, top_k_weights, tail_mass, entropy }
54    }
55
56    /// Check if attention is concentrated on first position (attention sink).
57    pub fn is_attention_sink(&self, threshold: f32) -> bool {
58        self.top_k_positions.first() == Some(&0)
59            && self.top_k_weights.first().copied().unwrap_or(0.0) > threshold
60    }
61
62    /// Check if attention is too uniform (confused model).
63    pub fn is_uniform(&self, entropy_threshold: f32) -> bool {
64        self.entropy > entropy_threshold
65    }
66
67    /// Check for repetition pattern (high weight on recent positions).
68    pub fn has_recency_bias(&self, recency_window: usize, threshold: f32) -> bool {
69        if self.query_pos == 0 {
70            return false;
71        }
72        let recency_start = self.query_pos.saturating_sub(recency_window);
73        let recent_mass: f32 = self
74            .top_k_positions
75            .iter()
76            .zip(self.top_k_weights.iter())
77            .filter(|(pos, _)| **pos >= recency_start)
78            .map(|(_, w)| w)
79            .sum();
80        recent_mass > threshold
81    }
82}
83
/// Settings controlling which attention heads get traced and how much of each
/// attention row is retained.
#[derive(Debug, Clone)]
pub struct AttentionTraceConfig {
    /// How many of the highest-weighted positions to keep for each traced head
    pub top_k: usize,
    /// Allow-list of layer indices to trace; `None` means every layer
    pub layers: Option<Vec<usize>>,
    /// Allow-list of head indices to trace; `None` means every head
    pub heads: Option<Vec<usize>>,
    /// Minimum weight to consider (positions with weight below this are ignored).
    /// NOTE(review): nothing in this file's visible code reads this field —
    /// confirm where the cutoff is actually applied.
    pub weight_threshold: f32,
}

impl Default for AttentionTraceConfig {
    /// Keep the top 10 positions, trace every layer and head, 0.01 weight cutoff.
    fn default() -> Self {
        let top_k = 10;
        let weight_threshold = 0.01;
        Self { top_k, weight_threshold, heads: None, layers: None }
    }
}

103impl AttentionTraceConfig {
104    /// Check if a layer should be traced.
105    pub fn should_trace_layer(&self, layer_idx: usize) -> bool {
106        self.layers.as_ref().is_none_or(|layers| layers.contains(&layer_idx))
107    }
108
109    /// Check if a head should be traced.
110    pub fn should_trace_head(&self, head_idx: usize) -> bool {
111        self.heads.as_ref().is_none_or(|heads| heads.contains(&head_idx))
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_attention_weight_trace_from_weights() {
        let weights = [0.1, 0.3, 0.4, 0.2];
        let trace = AttentionWeightTrace::from_weights(0, 0, 3, &weights, 2);

        assert_eq!(trace.layer_idx, 0);
        assert_eq!(trace.head_idx, 0);
        assert_eq!(trace.query_pos, 3);
        // Position 2 has the highest weight (0.4), then position 1 (0.3).
        assert_eq!(trace.top_k_positions, vec![2, 1]);
        assert_eq!(trace.top_k_weights, vec![0.4, 0.3]);
        // Mass outside the top-2: 0.1 + 0.2.
        assert!((trace.tail_mass - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_top_k_larger_than_row_is_clamped() {
        let trace = AttentionWeightTrace::from_weights(0, 0, 1, &[0.6, 0.4], 10);
        assert_eq!(trace.top_k_positions, vec![0, 1]);
        assert!(trace.tail_mass.abs() < 1e-6);
    }

    #[test]
    fn test_zero_k_puts_all_mass_in_tail() {
        let trace = AttentionWeightTrace::from_weights(0, 0, 1, &[0.5, 0.5], 0);
        assert!(trace.top_k_positions.is_empty());
        assert!(trace.top_k_weights.is_empty());
        assert!((trace.tail_mass - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_uniform_distribution_entropy() {
        // A uniform distribution over n positions has entropy ln(n) nats.
        let trace = AttentionWeightTrace::from_weights(0, 0, 3, &[0.25; 4], 1);
        assert!((trace.entropy - 4.0_f32.ln()).abs() < 1e-5);
        assert!(trace.is_uniform(1.0));
        assert!(!trace.is_uniform(2.0));
    }

    #[test]
    fn test_attention_sink_detection() {
        let trace = AttentionWeightTrace::from_weights(0, 0, 3, &[0.8, 0.1, 0.05, 0.05], 4);
        assert!(trace.is_attention_sink(0.5));
        assert!(!trace.is_attention_sink(0.9));

        // Top position is not 0, so this is not a sink regardless of threshold.
        let trace = AttentionWeightTrace::from_weights(0, 0, 1, &[0.1, 0.9], 2);
        assert!(!trace.is_attention_sink(0.05));
    }

    #[test]
    fn test_recency_bias_detection() {
        // Position 3 attending mostly to positions 1 and 2.
        let trace = AttentionWeightTrace::from_weights(0, 0, 3, &[0.05, 0.4, 0.5, 0.05], 4);
        assert!(trace.has_recency_bias(2, 0.5));

        // Mass concentrated far from the query position.
        let trace = AttentionWeightTrace::from_weights(0, 0, 3, &[0.9, 0.05, 0.05, 0.0], 4);
        assert!(!trace.has_recency_bias(1, 0.5));

        // The first query position has no history to be biased toward.
        let trace = AttentionWeightTrace::from_weights(0, 0, 0, &[1.0], 1);
        assert!(!trace.has_recency_bias(4, 0.5));
    }

    #[test]
    fn test_config_filters() {
        let all = AttentionTraceConfig::default();
        assert!(all.should_trace_layer(7));
        assert!(all.should_trace_head(31));

        let some = AttentionTraceConfig {
            layers: Some(vec![1, 3]),
            heads: Some(vec![0]),
            ..AttentionTraceConfig::default()
        };
        assert!(some.should_trace_layer(3));
        assert!(!some.should_trace_layer(2));
        assert!(some.should_trace_head(0));
        assert!(!some.should_trace_head(1));
    }
}