// candle_mi/cache/attention.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Attention pattern cache for storing and querying per-layer attention weights.
4//!
5//! [`AttentionCache`] stores post-softmax attention patterns from each layer
6//! of a forward pass, enabling downstream analysis such as attention knockout,
7//! steering, and visualization.
8//!
9//! Each stored tensor has shape `[batch, heads, seq_q, seq_k]`.
10
11use candle_core::{DType, Tensor};
12
13use crate::error::{MIError, Result};
14
/// Stores per-layer attention weights from a forward pass.
///
/// Each tensor has shape `[batch, heads, seq_q, seq_k]` — the post-softmax
/// attention pattern for one layer. Layers are identified by insertion
/// order: the first [`push`](Self::push) stores layer 0, the next layer 1,
/// and so on.
///
/// The positional query helpers ([`attention_from_position`](Self::attention_from_position),
/// [`attention_to_position`](Self::attention_to_position)) read batch
/// index 0 and average over the head dimension.
///
/// # Example
///
/// ```
/// use candle_mi::AttentionCache;
/// use candle_core::{Device, Tensor};
///
/// let mut cache = AttentionCache::with_capacity(32);
/// // shape [batch=1, heads=8, seq=10, seq=10]
/// cache.push(Tensor::zeros((1, 8, 10, 10), candle_core::DType::F32, &Device::Cpu).unwrap());
///
/// // Query what position 5 attends to in layer 0
/// let attn_row = cache.attention_from_position(0, 5).unwrap();
/// ```
#[derive(Debug)]
pub struct AttentionCache {
    /// Attention patterns per layer, each shape `[batch, heads, seq_q, seq_k]`;
    /// the index into this Vec is the layer index.
    patterns: Vec<Tensor>,
}
38
39impl AttentionCache {
40 /// Create an empty cache with capacity for `n_layers` layers.
41 #[must_use]
42 pub fn with_capacity(n_layers: usize) -> Self {
43 Self {
44 patterns: Vec::with_capacity(n_layers),
45 }
46 }
47
48 /// Add an attention pattern for the next layer.
49 ///
50 /// # Shapes
51 ///
52 /// - `pattern`: `[batch, heads, seq_q, seq_k]`
53 pub fn push(&mut self, pattern: Tensor) {
54 self.patterns.push(pattern);
55 }
56
57 /// Number of cached layers.
58 #[must_use]
59 pub const fn n_layers(&self) -> usize {
60 self.patterns.len()
61 }
62
63 /// Whether the cache is empty.
64 #[must_use]
65 pub const fn is_empty(&self) -> bool {
66 self.patterns.is_empty()
67 }
68
69 /// Get the raw attention tensor for a specific layer.
70 ///
71 /// # Shapes
72 ///
73 /// - returns: `[batch, heads, seq_q, seq_k]`
74 #[must_use]
75 pub fn get_layer(&self, layer: usize) -> Option<&Tensor> {
76 self.patterns.get(layer)
77 }
78
79 /// Get attention weights FROM a specific query position, averaged across heads.
80 ///
81 /// Returns a vector of length `seq_k` representing how much `position`
82 /// attends to every key position, averaged over all attention heads
83 /// (batch index 0).
84 ///
85 /// # Shapes
86 ///
87 /// - returns: `[seq_k]` as `Vec<f32>`
88 ///
89 /// # Errors
90 ///
91 /// Returns [`MIError::Hook`] if the layer is not in the cache or the
92 /// position is out of range.
93 pub fn attention_from_position(&self, layer: usize, position: usize) -> Result<Vec<f32>> {
94 let pattern = self
95 .patterns
96 .get(layer)
97 .ok_or_else(|| MIError::Hook(format!("layer {layer} not in attention cache")))?;
98
99 let seq_q = pattern.dim(2)?;
100 if position >= seq_q {
101 return Err(MIError::Hook(format!(
102 "position {position} out of range (seq_q={seq_q})"
103 )));
104 }
105
106 // pattern: [batch, heads, seq_q, seq_k]
107 // Select batch 0, all heads, the given query position, all key positions.
108 // PROMOTE: averaging attention weights; compute in F32 for precision
109 let attn_f32 = pattern.to_dtype(DType::F32)?;
110 // narrow(dim=0, start=0, len=1) → [1, heads, seq_q, seq_k]
111 let batch0 = attn_f32.narrow(0, 0, 1)?;
112 // narrow(dim=2, start=position, len=1) → [1, heads, 1, seq_k]
113 let row = batch0.narrow(2, position, 1)?;
114 // squeeze dims 0 and 2 → [heads, seq_k]
115 let row = row.squeeze(0)?.squeeze(1)?;
116 // mean across heads (dim 0) → [seq_k]
117 let avg = row.mean(0)?;
118 let result: Vec<f32> = avg.to_vec1()?;
119 Ok(result)
120 }
121
122 /// Get attention weights TO a specific key position, averaged across heads.
123 ///
124 /// Returns a vector of length `seq_q` representing how much every query
125 /// position attends to `position`, averaged over all attention heads
126 /// (batch index 0).
127 ///
128 /// # Shapes
129 ///
130 /// - returns: `[seq_q]` as `Vec<f32>`
131 ///
132 /// # Errors
133 ///
134 /// Returns [`MIError::Hook`] if the layer is not in the cache or the
135 /// position is out of range.
136 pub fn attention_to_position(&self, layer: usize, position: usize) -> Result<Vec<f32>> {
137 let pattern = self
138 .patterns
139 .get(layer)
140 .ok_or_else(|| MIError::Hook(format!("layer {layer} not in attention cache")))?;
141
142 let seq_k = pattern.dim(3)?;
143 if position >= seq_k {
144 return Err(MIError::Hook(format!(
145 "position {position} out of range (seq_k={seq_k})"
146 )));
147 }
148
149 // pattern: [batch, heads, seq_q, seq_k]
150 // PROMOTE: averaging attention weights; compute in F32 for precision
151 let attn_f32 = pattern.to_dtype(DType::F32)?;
152 let batch0 = attn_f32.narrow(0, 0, 1)?;
153 // narrow(dim=3, start=position, len=1) → [1, heads, seq_q, 1]
154 let col = batch0.narrow(3, position, 1)?;
155 // squeeze dims 0 and 3 → [heads, seq_q]
156 let col = col.squeeze(0)?.squeeze(2)?;
157 // mean across heads (dim 0) → [seq_q]
158 let avg = col.mean(0)?;
159 let result: Vec<f32> = avg.to_vec1()?;
160 Ok(result)
161 }
162
163 /// Get the top-k key positions that a given query position attends to most.
164 ///
165 /// Returns up to `k` pairs of `(key_position, weight)` sorted by
166 /// descending attention weight, averaged across heads (batch index 0).
167 ///
168 /// # Errors
169 ///
170 /// Returns [`MIError::Hook`] if the layer is not in the cache or the
171 /// position is out of range.
172 pub fn top_attended_positions(
173 &self,
174 layer: usize,
175 from_position: usize,
176 k: usize,
177 ) -> Result<Vec<(usize, f32)>> {
178 let attn = self.attention_from_position(layer, from_position)?;
179 let mut indexed: Vec<(usize, f32)> = attn.into_iter().enumerate().collect();
180 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
181 indexed.truncate(k);
182 Ok(indexed)
183 }
184
185 /// All cached patterns as a slice.
186 #[must_use]
187 pub fn patterns(&self) -> &[Tensor] {
188 &self.patterns
189 }
190}
191
192// ---------------------------------------------------------------------------
193// Tests
194// ---------------------------------------------------------------------------
195
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::float_cmp)]
mod tests {
    use super::*;
    use candle_core::Device;

    /// Absolute-tolerance comparison for attention weights.
    fn approx(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 1e-5
    }

    /// Build a one-layer cache with shape `[1, 2, 4, 4]`:
    /// 1 batch, 2 heads, 4 query positions, 4 key positions.
    /// Head 0 attends uniformly (0.25 everywhere); head 1 is
    /// diagonal-heavy (0.70 on the diagonal, 0.10 off it).
    fn sample_cache() -> AttentionCache {
        // Head 0: uniform attention.
        let mut data = vec![0.25_f32; 16];
        // Head 1: 0.10 background, then overwrite the diagonal with 0.70.
        data.extend(vec![0.10_f32; 16]);
        for i in 0..4 {
            data[16 + i * 4 + i] = 0.70;
        }

        let pattern = Tensor::from_vec(data, (1, 2, 4, 4), &Device::Cpu).unwrap();
        let mut cache = AttentionCache::with_capacity(1);
        cache.push(pattern);
        cache
    }

    #[test]
    fn empty_cache() {
        let cache = AttentionCache::with_capacity(32);
        assert!(cache.is_empty());
        assert_eq!(cache.n_layers(), 0);
        assert!(cache.get_layer(0).is_none());
    }

    #[test]
    fn push_and_get_layer() {
        let cache = sample_cache();
        assert!(!cache.is_empty());
        assert_eq!(cache.n_layers(), 1);
        assert_eq!(cache.get_layer(0).unwrap().dims(), &[1, 2, 4, 4]);
        assert!(cache.get_layer(1).is_none());
    }

    #[test]
    fn attention_from_position_values() {
        // Query position 0: head0 = [0.25; 4], head1 = [0.70, 0.10, 0.10, 0.10].
        // Head average: [0.475, 0.175, 0.175, 0.175].
        let attn = sample_cache().attention_from_position(0, 0).unwrap();
        assert_eq!(attn.len(), 4);
        assert!(approx(attn[0], 0.475));
        for &w in &attn[1..] {
            assert!(approx(w, 0.175));
        }
    }

    #[test]
    fn attention_from_position_out_of_range() {
        let cache = sample_cache();
        assert!(cache.attention_from_position(0, 10).is_err()); // bad position
        assert!(cache.attention_from_position(5, 0).is_err()); // bad layer
    }

    #[test]
    fn attention_to_position_values() {
        // Key position 0, read down column 0 of each head:
        // head 0 contributes 0.25 for every query row; head 1 contributes
        // [0.70, 0.10, 0.10, 0.10] → averages [0.475, 0.175, 0.175, 0.175].
        let attn = sample_cache().attention_to_position(0, 0).unwrap();
        assert_eq!(attn.len(), 4);
        assert!(approx(attn[0], 0.475));
        assert!(approx(attn[1], 0.175));
    }

    #[test]
    fn attention_to_position_out_of_range() {
        let cache = sample_cache();
        assert!(cache.attention_to_position(0, 10).is_err()); // bad position
        assert!(cache.attention_to_position(5, 0).is_err()); // bad layer
    }

    #[test]
    fn top_attended_positions_sorted() {
        // From query 0 the strongest key is position 0 (weight 0.475).
        let top = sample_cache().top_attended_positions(0, 0, 2).unwrap();
        assert_eq!(top.len(), 2);
        assert_eq!(top[0].0, 0);
        assert!(approx(top[0].1, 0.475));
    }

    #[test]
    fn top_attended_positions_k_larger_than_seq() {
        // Requesting more entries than positions exist returns them all.
        let top = sample_cache().top_attended_positions(0, 0, 100).unwrap();
        assert_eq!(top.len(), 4);
    }

    #[test]
    fn patterns_slice() {
        assert_eq!(sample_cache().patterns().len(), 1);
    }
}