/// Per-(layer, head, query position) snapshot of an attention weight distribution,
/// keeping only the top-k weights plus summary statistics.
#[derive(Debug, Clone, Default)]
pub struct AttentionWeightTrace {
    pub layer_idx: usize,
    pub head_idx: usize,
    pub query_pos: usize,
    /// Key positions with the largest weights, sorted by weight, descending.
    pub top_k_positions: Vec<usize>,
    /// Weights corresponding to `top_k_positions`, in the same order.
    pub top_k_weights: Vec<f32>,
    /// Attention mass not captured by the top-k entries.
    pub tail_mass: f32,
    /// Shannon entropy (natural log) of the full weight distribution.
    pub entropy: f32,
}

impl AttentionWeightTrace {
    /// Builds a trace from a full attention weight vector, recording the `k` largest weights.
    pub fn from_weights(
        layer_idx: usize,
        head_idx: usize,
        query_pos: usize,
        weights: &[f32],
        k: usize,
    ) -> Self {
        let k = k.min(weights.len());
        // Sort (position, weight) pairs by weight, descending; NaNs compare as equal.
        let mut pairs: Vec<(usize, f32)> = weights.iter().copied().enumerate().collect();
        pairs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        let top_k_positions: Vec<usize> = pairs.iter().take(k).map(|(pos, _)| *pos).collect();
        let top_k_weights: Vec<f32> = pairs.iter().take(k).map(|(_, w)| *w).collect();
        let top_k_mass: f32 = top_k_weights.iter().sum();
        let total_mass: f32 = weights.iter().sum();
        let tail_mass = (total_mass - top_k_mass).max(0.0);
        // Shannon entropy H = -sum(w * ln(w)); near-zero weights contribute
        // negligibly and are skipped, which also keeps ln(w) finite.
        let entropy: f32 =
            weights.iter().filter(|&&w| w > 1e-10).map(|&w| -w * w.ln()).sum();
        Self { layer_idx, head_idx, query_pos, top_k_positions, top_k_weights, tail_mass, entropy }
    }

    /// True when the heaviest weight falls on position 0 and exceeds `threshold`
    /// (the "attention sink" pattern).
    pub fn is_attention_sink(&self, threshold: f32) -> bool {
        self.top_k_positions.first() == Some(&0)
            && self.top_k_weights.first().copied().unwrap_or(0.0) > threshold
    }

    /// True when the distribution is near-uniform, i.e. its entropy exceeds `entropy_threshold`.
    pub fn is_uniform(&self, entropy_threshold: f32) -> bool {
        self.entropy > entropy_threshold
    }

    /// True when more than `threshold` of the traced attention mass falls within the
    /// last `recency_window` positions relative to `query_pos`.
    pub fn has_recency_bias(&self, recency_window: usize, threshold: f32) -> bool {
        if self.query_pos == 0 {
            return false;
        }
        let recency_start = self.query_pos.saturating_sub(recency_window);
        let recent_mass: f32 = self
            .top_k_positions
            .iter()
            .zip(self.top_k_weights.iter())
            .filter(|(pos, _)| **pos >= recency_start)
            .map(|(_, w)| w)
            .sum();
        recent_mass > threshold
    }
}

/// Controls which layers and heads are traced and how many top weights are kept per position.
#[derive(Debug, Clone)]
pub struct AttentionTraceConfig {
    /// Number of largest attention weights to record per query position.
    pub top_k: usize,
    /// Layers to trace; `None` traces every layer.
    pub layers: Option<Vec<usize>>,
    /// Heads to trace; `None` traces every head.
    pub heads: Option<Vec<usize>>,
    pub weight_threshold: f32,
}

impl Default for AttentionTraceConfig {
    fn default() -> Self {
        Self { top_k: 10, layers: None, heads: None, weight_threshold: 0.01 }
    }
}

impl AttentionTraceConfig {
    /// Whether `layer_idx` should be traced (all layers when `layers` is `None`).
    pub fn should_trace_layer(&self, layer_idx: usize) -> bool {
        self.layers.as_ref().is_none_or(|layers| layers.contains(&layer_idx))
    }

    /// Whether `head_idx` should be traced (all heads when `heads` is `None`).
    pub fn should_trace_head(&self, head_idx: usize) -> bool {
        self.heads.as_ref().is_none_or(|heads| heads.contains(&head_idx))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_attention_weight_trace_from_weights() {
        let weights = vec![0.1, 0.3, 0.4, 0.2];
        let trace = AttentionWeightTrace::from_weights(0, 0, 3, &weights, 2);
        assert_eq!(trace.layer_idx, 0);
        assert_eq!(trace.head_idx, 0);
        assert_eq!(trace.query_pos, 3);
        assert_eq!(trace.top_k_positions.len(), 2);
        assert_eq!(trace.top_k_positions[0], 2);
        assert_eq!(trace.top_k_positions[1], 1);
    }

    #[test]
    fn test_attention_sink_detection() {
        let weights = vec![0.8, 0.1, 0.05, 0.05];
        let trace = AttentionWeightTrace::from_weights(0, 0, 3, &weights, 4);
        assert!(trace.is_attention_sink(0.5));
    }
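
    // Additional illustrative test (sketch): checks tail_mass and entropy against
    // hand-computed values for the same toy distribution used above, plus is_uniform
    // on a uniform distribution. Tolerances are loose to allow for f32 rounding.
    #[test]
    fn test_tail_mass_entropy_and_uniformity() {
        let weights = vec![0.1, 0.3, 0.4, 0.2];
        let trace = AttentionWeightTrace::from_weights(0, 0, 3, &weights, 2);
        // Top-2 mass is 0.4 + 0.3 = 0.7, leaving 0.3 in the tail.
        assert!((trace.tail_mass - 0.3).abs() < 1e-5);
        // H = -(0.1 ln 0.1 + 0.2 ln 0.2 + 0.3 ln 0.3 + 0.4 ln 0.4) ≈ 1.2799 nats.
        assert!((trace.entropy - 1.2799).abs() < 1e-3);
        // A uniform distribution over 4 positions has entropy ln 4 ≈ 1.386 nats.
        let uniform = AttentionWeightTrace::from_weights(0, 0, 3, &[0.25; 4], 4);
        assert!(uniform.is_uniform(1.3));
        assert!(!trace.is_uniform(1.3));
    }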
    #[test]
    fn test_recency_bias_detection() {
        let weights = vec![0.05, 0.4, 0.5, 0.05];
        let trace = AttentionWeightTrace::from_weights(0, 0, 3, &weights, 4);
        assert!(trace.has_recency_bias(2, 0.5));
    }
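
    // Additional illustrative test (sketch): layer/head filtering in AttentionTraceConfig,
    // where `None` means "trace everything".
    #[test]
    fn test_trace_config_filtering() {
        let all = AttentionTraceConfig::default();
        assert!(all.should_trace_layer(7));
        assert!(all.should_trace_head(3));

        let filtered = AttentionTraceConfig {
            layers: Some(vec![0, 2]),
            heads: Some(vec![1]),
            ..Default::default()
        };
        assert!(filtered.should_trace_layer(0));
        assert!(filtered.should_trace_layer(2));
        assert!(!filtered.should_trace_layer(1));
        assert!(filtered.should_trace_head(1));
        assert!(!filtered.should_trace_head(0));
    }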
}