Skip to main content

candle_mi/cache/
attention.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Attention pattern cache for storing and querying per-layer attention weights.
4//!
5//! [`AttentionCache`] stores post-softmax attention patterns from each layer
6//! of a forward pass, enabling downstream analysis such as attention knockout,
7//! steering, and visualization.
8//!
9//! Each stored tensor has shape `[batch, heads, seq_q, seq_k]`.
10
11use candle_core::{DType, Tensor};
12
13use crate::error::{MIError, Result};
14
15/// Stores per-layer attention weights from a forward pass.
16///
17/// Each tensor has shape `[batch, heads, seq_q, seq_k]` — the post-softmax
18/// attention pattern for one layer.
19///
20/// # Example
21///
22/// ```
23/// use candle_mi::AttentionCache;
24/// use candle_core::{Device, Tensor};
25///
26/// let mut cache = AttentionCache::with_capacity(32);
27/// // shape [batch=1, heads=8, seq=10, seq=10]
28/// cache.push(Tensor::zeros((1, 8, 10, 10), candle_core::DType::F32, &Device::Cpu).unwrap());
29///
30/// // Query what position 5 attends to in layer 0
31/// let attn_row = cache.attention_from_position(0, 5).unwrap();
32/// ```
33#[derive(Debug)]
34pub struct AttentionCache {
35    /// Attention patterns per layer, each shape `[batch, heads, seq_q, seq_k]`.
36    patterns: Vec<Tensor>,
37}
38
39impl AttentionCache {
40    /// Create an empty cache with capacity for `n_layers` layers.
41    #[must_use]
42    pub fn with_capacity(n_layers: usize) -> Self {
43        Self {
44            patterns: Vec::with_capacity(n_layers),
45        }
46    }
47
48    /// Add an attention pattern for the next layer.
49    ///
50    /// # Shapes
51    ///
52    /// - `pattern`: `[batch, heads, seq_q, seq_k]`
53    pub fn push(&mut self, pattern: Tensor) {
54        self.patterns.push(pattern);
55    }
56
57    /// Number of cached layers.
58    #[must_use]
59    pub const fn n_layers(&self) -> usize {
60        self.patterns.len()
61    }
62
63    /// Whether the cache is empty.
64    #[must_use]
65    pub const fn is_empty(&self) -> bool {
66        self.patterns.is_empty()
67    }
68
69    /// Get the raw attention tensor for a specific layer.
70    ///
71    /// # Shapes
72    ///
73    /// - returns: `[batch, heads, seq_q, seq_k]`
74    #[must_use]
75    pub fn get_layer(&self, layer: usize) -> Option<&Tensor> {
76        self.patterns.get(layer)
77    }
78
79    /// Get attention weights FROM a specific query position, averaged across heads.
80    ///
81    /// Returns a vector of length `seq_k` representing how much `position`
82    /// attends to every key position, averaged over all attention heads
83    /// (batch index 0).
84    ///
85    /// # Shapes
86    ///
87    /// - returns: `[seq_k]` as `Vec<f32>`
88    ///
89    /// # Errors
90    ///
91    /// Returns [`MIError::Hook`] if the layer is not in the cache or the
92    /// position is out of range.
93    pub fn attention_from_position(&self, layer: usize, position: usize) -> Result<Vec<f32>> {
94        let pattern = self
95            .patterns
96            .get(layer)
97            .ok_or_else(|| MIError::Hook(format!("layer {layer} not in attention cache")))?;
98
99        let seq_q = pattern.dim(2)?;
100        if position >= seq_q {
101            return Err(MIError::Hook(format!(
102                "position {position} out of range (seq_q={seq_q})"
103            )));
104        }
105
106        // pattern: [batch, heads, seq_q, seq_k]
107        // Select batch 0, all heads, the given query position, all key positions.
108        // PROMOTE: averaging attention weights; compute in F32 for precision
109        let attn_f32 = pattern.to_dtype(DType::F32)?;
110        // narrow(dim=0, start=0, len=1) → [1, heads, seq_q, seq_k]
111        let batch0 = attn_f32.narrow(0, 0, 1)?;
112        // narrow(dim=2, start=position, len=1) → [1, heads, 1, seq_k]
113        let row = batch0.narrow(2, position, 1)?;
114        // squeeze dims 0 and 2 → [heads, seq_k]
115        let row = row.squeeze(0)?.squeeze(1)?;
116        // mean across heads (dim 0) → [seq_k]
117        let avg = row.mean(0)?;
118        let result: Vec<f32> = avg.to_vec1()?;
119        Ok(result)
120    }
121
122    /// Get attention weights TO a specific key position, averaged across heads.
123    ///
124    /// Returns a vector of length `seq_q` representing how much every query
125    /// position attends to `position`, averaged over all attention heads
126    /// (batch index 0).
127    ///
128    /// # Shapes
129    ///
130    /// - returns: `[seq_q]` as `Vec<f32>`
131    ///
132    /// # Errors
133    ///
134    /// Returns [`MIError::Hook`] if the layer is not in the cache or the
135    /// position is out of range.
136    pub fn attention_to_position(&self, layer: usize, position: usize) -> Result<Vec<f32>> {
137        let pattern = self
138            .patterns
139            .get(layer)
140            .ok_or_else(|| MIError::Hook(format!("layer {layer} not in attention cache")))?;
141
142        let seq_k = pattern.dim(3)?;
143        if position >= seq_k {
144            return Err(MIError::Hook(format!(
145                "position {position} out of range (seq_k={seq_k})"
146            )));
147        }
148
149        // pattern: [batch, heads, seq_q, seq_k]
150        // PROMOTE: averaging attention weights; compute in F32 for precision
151        let attn_f32 = pattern.to_dtype(DType::F32)?;
152        let batch0 = attn_f32.narrow(0, 0, 1)?;
153        // narrow(dim=3, start=position, len=1) → [1, heads, seq_q, 1]
154        let col = batch0.narrow(3, position, 1)?;
155        // squeeze dims 0 and 3 → [heads, seq_q]
156        let col = col.squeeze(0)?.squeeze(2)?;
157        // mean across heads (dim 0) → [seq_q]
158        let avg = col.mean(0)?;
159        let result: Vec<f32> = avg.to_vec1()?;
160        Ok(result)
161    }
162
163    /// Get the top-k key positions that a given query position attends to most.
164    ///
165    /// Returns up to `k` pairs of `(key_position, weight)` sorted by
166    /// descending attention weight, averaged across heads (batch index 0).
167    ///
168    /// # Errors
169    ///
170    /// Returns [`MIError::Hook`] if the layer is not in the cache or the
171    /// position is out of range.
172    pub fn top_attended_positions(
173        &self,
174        layer: usize,
175        from_position: usize,
176        k: usize,
177    ) -> Result<Vec<(usize, f32)>> {
178        let attn = self.attention_from_position(layer, from_position)?;
179        let mut indexed: Vec<(usize, f32)> = attn.into_iter().enumerate().collect();
180        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
181        indexed.truncate(k);
182        Ok(indexed)
183    }
184
185    /// All cached patterns as a slice.
186    #[must_use]
187    pub fn patterns(&self) -> &[Tensor] {
188        &self.patterns
189    }
190}
191
192// ---------------------------------------------------------------------------
193// Tests
194// ---------------------------------------------------------------------------
195
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::float_cmp)]
mod tests {
    use super::*;
    use candle_core::Device;

    /// Absolute tolerance for comparing head-averaged f32 weights.
    const TOL: f32 = 1e-5;

    /// Assert that `actual` matches `expected` element-wise within [`TOL`].
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < TOL, "index {i}: got {a}, expected {e}");
        }
    }

    /// Build a tiny attention cache for testing.
    ///
    /// Creates a single-layer cache with shape `[1, 2, 4, 4]`:
    /// - 1 batch, 2 heads, 4 query positions, 4 key positions
    /// - Head 0: uniform attention (0.25 everywhere)
    /// - Head 1: identity-like (diagonal = 0.7, off-diagonal = 0.1)
    ///
    /// Averaging the two heads therefore gives 0.475 on the diagonal and
    /// 0.175 everywhere else.
    fn sample_cache() -> AttentionCache {
        let device = Device::Cpu;

        #[rustfmt::skip]
        let data: Vec<f32> = vec![
            // Head 0 (uniform)
            0.25, 0.25, 0.25, 0.25,
            0.25, 0.25, 0.25, 0.25,
            0.25, 0.25, 0.25, 0.25,
            0.25, 0.25, 0.25, 0.25,
            // Head 1 (diagonal-heavy)
            0.70, 0.10, 0.10, 0.10,
            0.10, 0.70, 0.10, 0.10,
            0.10, 0.10, 0.70, 0.10,
            0.10, 0.10, 0.10, 0.70,
        ];

        let tensor = Tensor::from_vec(data, (1, 2, 4, 4), &device).unwrap();
        let mut cache = AttentionCache::with_capacity(1);
        cache.push(tensor);
        cache
    }

    #[test]
    fn empty_cache() {
        let cache = AttentionCache::with_capacity(32);
        assert_eq!(cache.n_layers(), 0);
        assert!(cache.is_empty());
        assert!(cache.get_layer(0).is_none());
    }

    #[test]
    fn push_and_get_layer() {
        let cache = sample_cache();
        assert_eq!(cache.n_layers(), 1);
        assert!(!cache.is_empty());

        let layer0 = cache.get_layer(0).unwrap();
        assert_eq!(layer0.dims(), &[1, 2, 4, 4]);
        assert!(cache.get_layer(1).is_none());
    }

    #[test]
    fn attention_from_position_values() {
        let cache = sample_cache();

        // Query 0: head0=[0.25; 4], head1=[0.70, 0.10, 0.10, 0.10]
        let attn = cache.attention_from_position(0, 0).unwrap();
        assert_close(&attn, &[0.475, 0.175, 0.175, 0.175]);

        // Query 2: the diagonal peak moves to key 2.
        let attn = cache.attention_from_position(0, 2).unwrap();
        assert_close(&attn, &[0.175, 0.175, 0.475, 0.175]);
    }

    #[test]
    fn attention_from_position_out_of_range() {
        let cache = sample_cache();
        // seq_q == 4, so 4 is the first invalid position.
        assert!(cache.attention_from_position(0, 4).is_err());
        assert!(cache.attention_from_position(0, 10).is_err());
        assert!(cache.attention_from_position(5, 0).is_err());
    }

    #[test]
    fn attention_to_position_values() {
        let cache = sample_cache();

        // Key 0: column 0 of head0=[0.25; 4], of head1=[0.70, 0.10, 0.10, 0.10]
        let attn = cache.attention_to_position(0, 0).unwrap();
        assert_close(&attn, &[0.475, 0.175, 0.175, 0.175]);

        // Key 3: only query 3 has the diagonal weight.
        let attn = cache.attention_to_position(0, 3).unwrap();
        assert_close(&attn, &[0.175, 0.175, 0.175, 0.475]);
    }

    #[test]
    fn attention_to_position_out_of_range() {
        let cache = sample_cache();
        // seq_k == 4, so 4 is the first invalid position.
        assert!(cache.attention_to_position(0, 4).is_err());
        assert!(cache.attention_to_position(0, 10).is_err());
        assert!(cache.attention_to_position(5, 0).is_err());
    }

    #[test]
    fn top_attended_positions_sorted() {
        let cache = sample_cache();

        // From query 0: key 0 has the highest weight (0.475); keys 1..=3 tie
        // at 0.175, and ties keep ascending key order, so key 1 is second.
        let top = cache.top_attended_positions(0, 0, 2).unwrap();
        assert_eq!(top.len(), 2);
        assert_eq!(top[0].0, 0);
        assert!((top[0].1 - 0.475).abs() < TOL);
        assert_eq!(top[1].0, 1);
        assert!((top[1].1 - 0.175).abs() < TOL);
    }

    #[test]
    fn top_attended_positions_k_larger_than_seq() {
        let cache = sample_cache();

        // k=100 but only 4 positions exist → returns all 4
        let top = cache.top_attended_positions(0, 0, 100).unwrap();
        assert_eq!(top.len(), 4);
    }

    #[test]
    fn top_attended_positions_k_zero() {
        let cache = sample_cache();
        assert!(cache.top_attended_positions(0, 0, 0).unwrap().is_empty());
    }

    #[test]
    fn non_rank4_pattern_is_rejected() {
        // A rank-3 tensor ([heads, seq_q, seq_k]) is not a valid pattern.
        let bad = Tensor::zeros((2, 4, 4), DType::F32, &Device::Cpu).unwrap();
        let mut cache = AttentionCache::with_capacity(1);
        cache.push(bad);

        assert!(cache.attention_from_position(0, 0).is_err());
        assert!(cache.attention_to_position(0, 0).is_err());
        assert!(cache.top_attended_positions(0, 0, 1).is_err());
    }

    #[test]
    fn patterns_slice() {
        let cache = sample_cache();
        assert_eq!(cache.patterns().len(), 1);
    }
}