// trueno/brick/tracing/attention.rs
/// Top-k attention-weight snapshot for a single (layer, head, query) triple.
///
/// Records where one query position attends most strongly, plus summary
/// statistics (tail mass and entropy) computed over the full weight row.
#[derive(Debug, Clone, Default)]
pub struct AttentionWeightTrace {
    /// Transformer layer this trace was taken from.
    pub layer_idx: usize,
    /// Attention head within the layer.
    pub head_idx: usize,
    /// Query (row) position whose attention row is traced.
    pub query_pos: usize,
    /// Key positions holding the k largest weights, descending by weight.
    pub top_k_positions: Vec<usize>,
    /// Weights paired with `top_k_positions`, descending.
    pub top_k_weights: Vec<f32>,
    /// Probability mass outside the top-k entries (clamped to >= 0).
    pub tail_mass: f32,
    /// Shannon entropy (nats) of the full weight distribution.
    pub entropy: f32,
}

impl AttentionWeightTrace {
    /// Builds a trace from one row of attention weights.
    ///
    /// `k` is clamped to `weights.len()`, so an oversized `k` simply keeps
    /// every position. Weight ordering uses `f32::total_cmp`, which is a
    /// genuine total order: ties and NaNs sort deterministically instead of
    /// being papered over with `Ordering::Equal`.
    pub fn from_weights(
        layer_idx: usize,
        head_idx: usize,
        query_pos: usize,
        weights: &[f32],
        k: usize,
    ) -> Self {
        let k = k.min(weights.len());

        // Pair each weight with its key position, then sort largest-first.
        // `sort_unstable_by` avoids the stable sort's extra allocation;
        // stability is irrelevant since (position, weight) pairs are unique.
        let mut pairs: Vec<(usize, f32)> = weights.iter().copied().enumerate().collect();
        pairs.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));

        let top_k_positions: Vec<usize> = pairs.iter().take(k).map(|&(pos, _)| pos).collect();
        let top_k_weights: Vec<f32> = pairs.iter().take(k).map(|&(_, w)| w).collect();

        // Mass not captured by the top-k entries; clamp because f32 rounding
        // can push the subtraction slightly negative.
        let top_k_mass: f32 = top_k_weights.iter().sum();
        let total_mass: f32 = weights.iter().sum();
        let tail_mass = (total_mass - top_k_mass).max(0.0);

        // Shannon entropy in nats. The filter already guarantees w > 1e-10,
        // so ln(w) is finite; the previous extra clamp to f32::EPSILON
        // (~1.19e-7 > 1e-10) only distorted the contribution of weights in
        // the range (1e-10, EPSILON).
        let entropy = weights.iter().filter(|&&w| w > 1e-10).map(|&w| -w * w.ln()).sum();

        Self { layer_idx, head_idx, query_pos, top_k_positions, top_k_weights, tail_mass, entropy }
    }

    /// True when key position 0 dominates: the strongest weight sits at
    /// position 0 and exceeds `threshold` (the "attention sink" pattern).
    pub fn is_attention_sink(&self, threshold: f32) -> bool {
        self.top_k_positions.first() == Some(&0)
            && self.top_k_weights.first().copied().unwrap_or(0.0) > threshold
    }

    /// True when the distribution is close to uniform, i.e. its entropy
    /// exceeds `entropy_threshold` (uniform over n positions is ln(n) nats).
    pub fn is_uniform(&self, entropy_threshold: f32) -> bool {
        self.entropy > entropy_threshold
    }

    /// True when more than `threshold` of the recorded top-k mass falls in
    /// the last `recency_window` positions up to the query.
    ///
    /// Returns false at `query_pos == 0`, where "recent" is meaningless.
    /// Only the stored top-k entries are counted, so this is a lower bound
    /// on the true recent mass.
    pub fn has_recency_bias(&self, recency_window: usize, threshold: f32) -> bool {
        if self.query_pos == 0 {
            return false;
        }
        let recency_start = self.query_pos.saturating_sub(recency_window);
        let recent_mass: f32 = self
            .top_k_positions
            .iter()
            .zip(self.top_k_weights.iter())
            .filter(|(pos, _)| **pos >= recency_start)
            .map(|(_, w)| w)
            .sum();
        recent_mass > threshold
    }
}
83
/// Configuration controlling which attention weights get traced.
#[derive(Debug, Clone)]
pub struct AttentionTraceConfig {
    /// How many of the largest weights to keep per traced row.
    pub top_k: usize,
    /// Layers to trace; `None` means every layer.
    pub layers: Option<Vec<usize>>,
    /// Heads to trace; `None` means every head.
    pub heads: Option<Vec<usize>>,
    /// Minimum weight considered significant.
    pub weight_threshold: f32,
}

impl Default for AttentionTraceConfig {
    fn default() -> Self {
        Self {
            top_k: 10,
            layers: None,
            heads: None,
            weight_threshold: 0.01,
        }
    }
}

impl AttentionTraceConfig {
    /// Whether `layer_idx` is selected; an absent filter selects all layers.
    pub fn should_trace_layer(&self, layer_idx: usize) -> bool {
        match &self.layers {
            Some(layers) => layers.contains(&layer_idx),
            None => true,
        }
    }

    /// Whether `head_idx` is selected; an absent filter selects all heads.
    pub fn should_trace_head(&self, head_idx: usize) -> bool {
        match &self.heads {
            Some(heads) => heads.contains(&head_idx),
            None => true,
        }
    }
}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_attention_weight_trace_from_weights() {
        let weights = vec![0.1, 0.3, 0.4, 0.2];
        let trace = AttentionWeightTrace::from_weights(0, 0, 3, &weights, 2);

        assert_eq!(trace.layer_idx, 0);
        assert_eq!(trace.head_idx, 0);
        assert_eq!(trace.query_pos, 3);
        assert_eq!(trace.top_k_positions.len(), 2);
        assert_eq!(trace.top_k_positions[0], 2);
        assert_eq!(trace.top_k_positions[1], 1);
        // Weights must travel with their positions, descending.
        assert_eq!(trace.top_k_weights, vec![0.4, 0.3]);
        // Mass outside the top-2 is 0.1 + 0.2.
        assert!((trace.tail_mass - 0.3).abs() < 1e-5);
    }

    #[test]
    fn test_from_weights_clamps_oversized_k() {
        let trace = AttentionWeightTrace::from_weights(0, 0, 0, &[0.6, 0.4], 10);
        assert_eq!(trace.top_k_positions.len(), 2);
        assert!(trace.tail_mass.abs() < 1e-6);
    }

    #[test]
    fn test_attention_sink_detection() {
        let weights = vec![0.8, 0.1, 0.05, 0.05];
        let trace = AttentionWeightTrace::from_weights(0, 0, 3, &weights, 4);
        assert!(trace.is_attention_sink(0.5));
        // A threshold above the sink weight must not trigger.
        assert!(!trace.is_attention_sink(0.9));
    }

    #[test]
    fn test_uniform_detection() {
        // Uniform over 4 positions has entropy ln(4) ~ 1.386 nats.
        let trace = AttentionWeightTrace::from_weights(0, 0, 3, &[0.25; 4], 4);
        assert!(trace.is_uniform(1.0));
    }

    #[test]
    fn test_recency_bias_detection() {
        let weights = vec![0.05, 0.4, 0.5, 0.05];
        let trace = AttentionWeightTrace::from_weights(0, 0, 3, &weights, 4);
        assert!(trace.has_recency_bias(2, 0.5));
        // Query position 0 has no history, so never reports recency bias.
        let first = AttentionWeightTrace::from_weights(0, 0, 0, &[1.0], 1);
        assert!(!first.has_recency_bias(2, 0.0));
    }
}