candle_mi/hooks.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Hook system for activation capture and intervention.
4//!
5//! Provides [`HookPoint`] (named locations in a forward pass),
6//! [`HookSpec`] (what to capture and where to intervene), and
7//! [`HookCache`] (captured tensors from a forward pass).
8//!
9//! See `design/hook-system.md` for the design rationale.
10
11use std::collections::{HashMap, HashSet};
12use std::fmt;
13use std::str::FromStr;
14
15use candle_core::Tensor;
16
17use crate::error::{MIError, Result};
18use crate::interp::intervention::{StateKnockoutSpec, StateSteeringSpec};
19
20// ---------------------------------------------------------------------------
21// HookPoint
22// ---------------------------------------------------------------------------
23
/// A named site in the forward pass where an activation can be captured or
/// an intervention applied.
///
/// Names follow the `TransformerLens` hook-point convention and convert to
/// and from strings through [`Display`](std::fmt::Display) and [`FromStr`].
///
/// # String conversion
///
/// ```
/// use candle_mi::HookPoint;
///
/// let hook = HookPoint::AttnPattern(5);
/// assert_eq!(hook.to_string(), "blocks.5.attn.hook_pattern");
///
/// let parsed: HookPoint = "blocks.5.attn.hook_pattern".parse().unwrap();
/// assert_eq!(parsed, hook);
/// ```
///
/// Strings that match no standard name parse as [`HookPoint::Custom`], which
/// gives backends an escape hatch for their own hook points.
#[non_exhaustive]
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum HookPoint {
    // ===== Embedding =====
    /// Output of the token embedding (`hook_embed`).
    Embed,

    // ===== Per-layer (transformer); the payload is the layer index =====
    /// Residual stream entering layer `i` (`blocks.{i}.hook_resid_pre`).
    ResidPre(usize),
    /// Attention queries of layer `i` (`blocks.{i}.attn.hook_q`).
    AttnQ(usize),
    /// Attention keys of layer `i` (`blocks.{i}.attn.hook_k`).
    AttnK(usize),
    /// Attention values of layer `i` (`blocks.{i}.attn.hook_v`).
    AttnV(usize),
    /// Attention scores of layer `i` before softmax (`blocks.{i}.attn.hook_scores`).
    AttnScores(usize),
    /// Attention pattern of layer `i` after softmax (`blocks.{i}.attn.hook_pattern`).
    AttnPattern(usize),
    /// Output of the attention sub-block of layer `i` (`blocks.{i}.hook_attn_out`).
    AttnOut(usize),
    /// Residual stream of layer `i` after attention and before the MLP
    /// (`blocks.{i}.hook_resid_mid`).
    ResidMid(usize),
    /// MLP pre-activation of layer `i` (`blocks.{i}.mlp.hook_pre`).
    MlpPre(usize),
    /// MLP post-activation of layer `i` (`blocks.{i}.mlp.hook_post`).
    MlpPost(usize),
    /// Output of the MLP sub-block of layer `i` (`blocks.{i}.hook_mlp_out`).
    MlpOut(usize),
    /// Residual stream leaving layer `i` (`blocks.{i}.hook_resid_post`).
    ResidPost(usize),

    // ===== Final =====
    /// Output of the final layer norm (`hook_final_norm`).
    FinalNorm,

    // ===== RWKV only =====
    /// Recurrent state of RWKV layer `i` (`blocks.{i}.rwkv.hook_state`).
    RwkvState(usize),
    /// Decay vector of RWKV layer `i` (`blocks.{i}.rwkv.hook_decay`).
    RwkvDecay(usize),
    /// Effective attention of RWKV layer `i` (`blocks.{i}.rwkv.hook_effective_attn`).
    ///
    /// Shape: `[batch, heads, seq_query, seq_source]`. Obtained from the WKV
    /// recurrence by measuring how much each source position contributes to
    /// each query position's output, then normalised with `ReLU` + L1.
    RwkvEffectiveAttn(usize),

    // ===== Escape hatch =====
    /// Backend-specific hook point that the standard variants do not cover.
    Custom(String),
}
99
100impl fmt::Display for HookPoint {
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 match self {
103 Self::Embed => write!(f, "hook_embed"),
104 Self::ResidPre(i) => write!(f, "blocks.{i}.hook_resid_pre"),
105 Self::AttnQ(i) => write!(f, "blocks.{i}.attn.hook_q"),
106 Self::AttnK(i) => write!(f, "blocks.{i}.attn.hook_k"),
107 Self::AttnV(i) => write!(f, "blocks.{i}.attn.hook_v"),
108 Self::AttnScores(i) => write!(f, "blocks.{i}.attn.hook_scores"),
109 Self::AttnPattern(i) => write!(f, "blocks.{i}.attn.hook_pattern"),
110 Self::AttnOut(i) => write!(f, "blocks.{i}.hook_attn_out"),
111 Self::ResidMid(i) => write!(f, "blocks.{i}.hook_resid_mid"),
112 Self::MlpPre(i) => write!(f, "blocks.{i}.mlp.hook_pre"),
113 Self::MlpPost(i) => write!(f, "blocks.{i}.mlp.hook_post"),
114 Self::MlpOut(i) => write!(f, "blocks.{i}.hook_mlp_out"),
115 Self::ResidPost(i) => write!(f, "blocks.{i}.hook_resid_post"),
116 Self::FinalNorm => write!(f, "hook_final_norm"),
117 Self::RwkvState(i) => write!(f, "blocks.{i}.rwkv.hook_state"),
118 Self::RwkvDecay(i) => write!(f, "blocks.{i}.rwkv.hook_decay"),
119 Self::RwkvEffectiveAttn(i) => write!(f, "blocks.{i}.rwkv.hook_effective_attn"),
120 Self::Custom(s) => write!(f, "{s}"),
121 }
122 }
123}
124
125/// Parse a `TransformerLens`-style string into a [`HookPoint`].
126///
127/// Unknown strings produce [`HookPoint::Custom`] rather than an error.
128impl FromStr for HookPoint {
129 type Err = std::convert::Infallible;
130
131 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
132 Ok(parse_hook_string(s))
133 }
134}
135
136/// Allow `hooks.capture("blocks.5.attn.hook_pattern")` via `Into<HookPoint>`.
137impl From<&str> for HookPoint {
138 fn from(s: &str) -> Self {
139 parse_hook_string(s)
140 }
141}
142
143/// Parse a hook string, falling back to [`HookPoint::Custom`] for unknown patterns.
144fn parse_hook_string(s: &str) -> HookPoint {
145 match s {
146 "hook_embed" => return HookPoint::Embed,
147 "hook_final_norm" => return HookPoint::FinalNorm,
148 _ => {}
149 }
150
151 // Try "blocks.{layer}.{suffix}" pattern.
152 if let Some(rest) = s.strip_prefix("blocks.") {
153 if let Some((layer_str, suffix)) = rest.split_once('.') {
154 if let Ok(layer) = layer_str.parse::<usize>() {
155 return match suffix {
156 "hook_resid_pre" => HookPoint::ResidPre(layer),
157 "attn.hook_q" => HookPoint::AttnQ(layer),
158 "attn.hook_k" => HookPoint::AttnK(layer),
159 "attn.hook_v" => HookPoint::AttnV(layer),
160 "attn.hook_scores" => HookPoint::AttnScores(layer),
161 "attn.hook_pattern" => HookPoint::AttnPattern(layer),
162 "hook_attn_out" => HookPoint::AttnOut(layer),
163 "hook_resid_mid" => HookPoint::ResidMid(layer),
164 "mlp.hook_pre" => HookPoint::MlpPre(layer),
165 "mlp.hook_post" => HookPoint::MlpPost(layer),
166 "hook_mlp_out" => HookPoint::MlpOut(layer),
167 "hook_resid_post" => HookPoint::ResidPost(layer),
168 "rwkv.hook_state" => HookPoint::RwkvState(layer),
169 "rwkv.hook_decay" => HookPoint::RwkvDecay(layer),
170 "rwkv.hook_effective_attn" => HookPoint::RwkvEffectiveAttn(layer),
171 _ => HookPoint::Custom(s.to_string()),
172 };
173 }
174 }
175 }
176
177 HookPoint::Custom(s.to_string())
178}
179
180// ---------------------------------------------------------------------------
181// Intervention
182// ---------------------------------------------------------------------------
183
184/// An intervention to apply at a hook point during the forward pass.
185///
186/// Interventions modify activations in-place as they flow through the model.
187/// They are specified as part of a [`HookSpec`] and applied by the backend
188/// at the corresponding [`HookPoint`].
189#[non_exhaustive]
190#[derive(Debug, Clone)]
191pub enum Intervention {
192 /// Replace the tensor entirely with a provided value.
193 Replace(Tensor),
194
195 /// Add a vector to the activation (e.g., residual stream steering).
196 Add(Tensor),
197
198 /// Apply a pre-softmax knockout mask.
199 ///
200 /// The mask tensor contains `0.0` for positions to keep and
201 /// `-inf` for positions to knock out. Added to attention scores.
202 Knockout(Tensor),
203
204 /// Scale attention weights by a constant factor.
205 Scale(f64),
206
207 /// Zero the tensor at this hook point.
208 Zero,
209}
210
211// ---------------------------------------------------------------------------
212// Intervention application
213// ---------------------------------------------------------------------------
214
/// Apply a single [`Intervention`] to a tensor.
///
/// Used by backend implementations at each hook point that supports
/// interventions (e.g., Embed, `AttnScores`, `AttnPattern`).
///
/// Tensors carried by the intervention (`Replace`, `Add`, `Knockout`) are
/// first converted to the activation's dtype. For example, an F32 steering
/// vector (as accumulated for CLT injection) or an F32 knockout mask can then
/// be used in a BF16 forward pass. A `Replace` tensor cannot leak a foreign
/// dtype into the rest of the pass.
///
/// # Shapes
/// - `tensor`: any shape — the activation at the hook point.
/// - returns: same shape as `tensor`. For `Replace`, the caller must supply a
///   replacement of that shape; it is not checked here.
///
/// # Errors
///
/// Returns [`MIError::Model`] if the underlying tensor operation fails
/// (including dtype conversion or an incompatible broadcast).
#[cfg(any(feature = "transformer", feature = "rwkv"))]
pub(crate) fn apply_intervention(tensor: &Tensor, intervention: &Intervention) -> Result<Tensor> {
    match intervention {
        Intervention::Replace(replacement) => cast_like(replacement, tensor),
        Intervention::Add(delta) => Ok(tensor.broadcast_add(&cast_like(delta, tensor)?)?),
        Intervention::Knockout(mask) => Ok(tensor.broadcast_add(&cast_like(mask, tensor)?)?),
        Intervention::Scale(factor) => Ok((tensor * *factor)?),
        Intervention::Zero => Ok(tensor.zeros_like()?),
    }
}

/// Return `src` in the dtype of `like`, converting only when the dtypes differ.
///
/// Cloning a candle `Tensor` only bumps a reference count, so the no-op path
/// copies no data.
#[cfg(any(feature = "transformer", feature = "rwkv"))]
fn cast_like(src: &Tensor, like: &Tensor) -> Result<Tensor> {
    if src.dtype() == like.dtype() {
        Ok(src.clone())
    } else {
        Ok(src.to_dtype(like.dtype())?)
    }
}
247
248// ---------------------------------------------------------------------------
249// HookSpec
250// ---------------------------------------------------------------------------
251
252/// Declares which activations to capture and which interventions to apply.
253///
254/// Passed to [`MIBackend::forward`](crate::MIBackend::forward). When empty,
255/// the forward pass has zero overhead (no clones, no extra allocations).
256///
257/// # Example
258///
259/// ```
260/// use candle_mi::{HookPoint, HookSpec};
261///
262/// let mut hooks = HookSpec::new();
263/// hooks.capture(HookPoint::AttnPattern(5))
264/// .capture("blocks.5.hook_resid_post");
265/// ```
266#[derive(Debug, Clone, Default)]
267pub struct HookSpec {
268 /// Hook points to capture during the forward pass.
269 captures: HashSet<HookPoint>,
270 /// Interventions to apply, stored as (`hook_point`, intervention) pairs.
271 interventions: Vec<(HookPoint, Intervention)>,
272 /// RWKV state knockout specification (skip kv write at specified positions).
273 state_knockout: Option<StateKnockoutSpec>,
274 /// RWKV state steering specification (scale kv write at specified positions).
275 state_steering: Option<StateSteeringSpec>,
276}
277
278impl HookSpec {
279 /// Create an empty hook specification (no captures, no interventions).
280 #[must_use]
281 pub fn new() -> Self {
282 Self::default()
283 }
284
285 /// Request capture of the activation at the given hook point.
286 pub fn capture<H: Into<HookPoint>>(&mut self, hook: H) -> &mut Self {
287 self.captures.insert(hook.into());
288 self
289 }
290
291 /// Register an intervention at the given hook point.
292 pub fn intervene<H: Into<HookPoint>>(
293 &mut self,
294 hook: H,
295 intervention: Intervention,
296 ) -> &mut Self {
297 self.interventions.push((hook.into(), intervention));
298 self
299 }
300
301 /// Check whether a specific hook point should be captured.
302 #[must_use]
303 pub fn is_captured(&self, hook: &HookPoint) -> bool {
304 self.captures.contains(hook)
305 }
306
307 /// Check whether this spec has no captures, no interventions, and no
308 /// state specs (knockout/steering).
309 #[must_use]
310 pub fn is_empty(&self) -> bool {
311 self.captures.is_empty()
312 && self.interventions.is_empty()
313 && self.state_knockout.is_none()
314 && self.state_steering.is_none()
315 }
316
317 /// Number of requested captures.
318 #[must_use]
319 pub fn num_captures(&self) -> usize {
320 self.captures.len()
321 }
322
323 /// Number of registered interventions.
324 #[must_use]
325 pub const fn num_interventions(&self) -> usize {
326 self.interventions.len()
327 }
328
329 /// Iterate over interventions registered at a specific hook point.
330 pub fn interventions_at(&self, hook: &HookPoint) -> impl Iterator<Item = &Intervention> {
331 self.interventions
332 .iter()
333 .filter(move |(h, _)| h == hook)
334 .map(|(_, intervention)| intervention)
335 }
336
337 /// Check whether any intervention targets the given hook point.
338 #[must_use]
339 pub fn has_intervention_at(&self, hook: &HookPoint) -> bool {
340 self.interventions.iter().any(|(h, _)| h == hook)
341 }
342
343 /// Set an RWKV state knockout specification.
344 ///
345 /// At specified token positions, the WKV recurrence skips the kv write,
346 /// effectively making those tokens invisible to all future positions.
347 pub fn set_state_knockout(&mut self, spec: StateKnockoutSpec) -> &mut Self {
348 self.state_knockout = Some(spec);
349 self
350 }
351
352 /// Set an RWKV state steering specification.
353 ///
354 /// At specified token positions, the WKV recurrence scales the kv write
355 /// by the given factor, amplifying or dampening the token's contribution.
356 pub fn set_state_steering(&mut self, spec: StateSteeringSpec) -> &mut Self {
357 self.state_steering = Some(spec);
358 self
359 }
360
361 /// Get the state knockout specification, if any.
362 #[must_use]
363 pub const fn state_knockout(&self) -> Option<&StateKnockoutSpec> {
364 self.state_knockout.as_ref()
365 }
366
367 /// Get the state steering specification, if any.
368 #[must_use]
369 pub const fn state_steering(&self) -> Option<&StateSteeringSpec> {
370 self.state_steering.as_ref()
371 }
372
373 /// Merge all captures and interventions from another [`HookSpec`] into this one.
374 ///
375 /// Useful for combining multiple intervention sources (e.g., suppress +
376 /// inject in CLT steering).
377 pub fn extend(&mut self, other: &Self) -> &mut Self {
378 self.captures.extend(other.captures.iter().cloned());
379 self.interventions
380 .extend(other.interventions.iter().cloned());
381 self
382 }
383}
384
385// ---------------------------------------------------------------------------
386// HookCache
387// ---------------------------------------------------------------------------
388
389/// Tensors captured during a forward pass, plus the output logits.
390///
391/// Returned by [`MIBackend::forward`](crate::MIBackend::forward). Use
392/// [`get`](Self::get) to retrieve activations at specific hook points.
393///
394/// # Example
395///
396/// ```
397/// use candle_mi::{HookCache, HookPoint};
398/// use candle_core::{Device, Tensor};
399///
400/// let logits = Tensor::zeros((1, 10, 32000), candle_core::DType::F32, &Device::Cpu).unwrap();
401/// let mut cache = HookCache::new(logits);
402///
403/// // Store a captured activation
404/// let pattern = Tensor::zeros((1, 8, 10, 10), candle_core::DType::F32, &Device::Cpu).unwrap();
405/// cache.store(HookPoint::AttnPattern(5), pattern);
406///
407/// // Retrieve captured activations
408/// let output = cache.output();
409/// let attn = cache.get(&HookPoint::AttnPattern(5)).unwrap();
410/// ```
411#[derive(Debug)]
412pub struct HookCache {
413 /// Output tensor from the forward pass (typically logits).
414 output: Tensor,
415 /// Captured activations keyed by hook point.
416 captures: HashMap<HookPoint, Tensor>,
417}
418
419impl HookCache {
420 /// Create a new cache with the given output tensor and no captures.
421 #[must_use]
422 pub fn new(output: Tensor) -> Self {
423 Self {
424 output,
425 captures: HashMap::new(),
426 }
427 }
428
429 /// The output tensor from the forward pass.
430 #[must_use]
431 pub const fn output(&self) -> &Tensor {
432 &self.output
433 }
434
435 /// Consume the cache and return the output tensor.
436 #[must_use]
437 pub fn into_output(self) -> Tensor {
438 self.output
439 }
440
441 /// Retrieve a captured tensor by hook point.
442 #[must_use]
443 pub fn get(&self, hook: &HookPoint) -> Option<&Tensor> {
444 self.captures.get(hook)
445 }
446
447 /// Retrieve a captured tensor, returning an error if not found.
448 ///
449 /// # Errors
450 ///
451 /// Returns [`MIError::Hook`] if the hook point was not captured.
452 pub fn require(&self, hook: &HookPoint) -> Result<&Tensor> {
453 self.captures
454 .get(hook)
455 .ok_or_else(|| MIError::Hook(format!("hook point `{hook}` was not captured")))
456 }
457
458 /// Store a captured activation. Called by backend implementations.
459 pub fn store(&mut self, hook: HookPoint, tensor: Tensor) {
460 self.captures.insert(hook, tensor);
461 }
462
463 /// Replace the output tensor (e.g., after computing final logits).
464 ///
465 /// This allows the forward pass to collect captures into a cache
466 /// initialized with a placeholder, then set the real output at the end.
467 pub fn set_output(&mut self, output: Tensor) {
468 self.output = output;
469 }
470
471 /// Number of captured tensors (excludes the output).
472 #[must_use]
473 pub fn num_captures(&self) -> usize {
474 self.captures.len()
475 }
476}
477
478// ---------------------------------------------------------------------------
479// Tests
480// ---------------------------------------------------------------------------
481
#[cfg(test)]
// Panicking on unexpected values is the desired failure mode in unit tests.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn hook_point_display_roundtrip() {
        let cases: Vec<(HookPoint, &str)> = vec![
            (HookPoint::Embed, "hook_embed"),
            (HookPoint::FinalNorm, "hook_final_norm"),
            (HookPoint::ResidPre(0), "blocks.0.hook_resid_pre"),
            (HookPoint::AttnQ(3), "blocks.3.attn.hook_q"),
            (HookPoint::AttnK(3), "blocks.3.attn.hook_k"),
            (HookPoint::AttnV(3), "blocks.3.attn.hook_v"),
            (HookPoint::AttnScores(7), "blocks.7.attn.hook_scores"),
            (HookPoint::AttnPattern(5), "blocks.5.attn.hook_pattern"),
            (HookPoint::AttnOut(2), "blocks.2.hook_attn_out"),
            (HookPoint::ResidMid(11), "blocks.11.hook_resid_mid"),
            (HookPoint::MlpPre(1), "blocks.1.mlp.hook_pre"),
            (HookPoint::MlpPost(1), "blocks.1.mlp.hook_post"),
            (HookPoint::MlpOut(4), "blocks.4.hook_mlp_out"),
            (HookPoint::ResidPost(9), "blocks.9.hook_resid_post"),
            (HookPoint::RwkvState(6), "blocks.6.rwkv.hook_state"),
            (HookPoint::RwkvDecay(6), "blocks.6.rwkv.hook_decay"),
            (
                HookPoint::RwkvEffectiveAttn(6),
                "blocks.6.rwkv.hook_effective_attn",
            ),
        ];

        for (hook, expected_str) in cases {
            // Display
            assert_eq!(
                hook.to_string(),
                expected_str,
                "Display failed for {hook:?}"
            );
            // FromStr roundtrip
            let parsed: HookPoint = expected_str.parse().unwrap();
            assert_eq!(parsed, hook, "FromStr failed for {expected_str:?}");
            // From<&str>
            let from_str: HookPoint = HookPoint::from(expected_str);
            assert_eq!(from_str, hook, "From<&str> failed for {expected_str:?}");
        }
    }

    #[test]
    fn unknown_string_becomes_custom() {
        let hook: HookPoint = "some.unknown.hook".parse().unwrap();
        assert_eq!(hook, HookPoint::Custom("some.unknown.hook".to_string()));
    }

    #[test]
    fn custom_hook_displays_verbatim() {
        let hook = HookPoint::from("backend.specific.thing");
        assert_eq!(hook, HookPoint::Custom("backend.specific.thing".to_string()));
        assert_eq!(hook.to_string(), "backend.specific.thing");
    }

    #[test]
    fn malformed_block_strings_become_custom() {
        for s in [
            "blocks.x.hook_resid_pre",
            "blocks.5.not_a_hook",
            "blocks.5",
            "blocks.",
        ] {
            assert_eq!(HookPoint::from(s), HookPoint::Custom(s.to_string()), "{s:?}");
        }
    }

    #[test]
    fn hook_spec_capture_and_query() {
        let mut spec = HookSpec::new();
        assert!(spec.is_empty());

        spec.capture(HookPoint::AttnPattern(5));
        spec.capture("blocks.3.hook_resid_post");

        assert!(!spec.is_empty());
        assert_eq!(spec.num_captures(), 2);
        assert!(spec.is_captured(&HookPoint::AttnPattern(5)));
        assert!(spec.is_captured(&HookPoint::ResidPost(3)));
        assert!(!spec.is_captured(&HookPoint::Embed));
    }

    #[test]
    fn hook_spec_intervention_query() {
        let mut spec = HookSpec::new();
        spec.intervene(HookPoint::AttnScores(5), Intervention::Zero);
        spec.intervene(HookPoint::AttnScores(5), Intervention::Scale(2.0));
        spec.intervene(HookPoint::ResidPost(10), Intervention::Zero);

        assert_eq!(spec.num_interventions(), 3);
        assert!(spec.has_intervention_at(&HookPoint::AttnScores(5)));
        assert!(!spec.has_intervention_at(&HookPoint::Embed));

        let at_5: Vec<_> = spec.interventions_at(&HookPoint::AttnScores(5)).collect();
        assert_eq!(at_5.len(), 2);
    }

    #[test]
    fn hook_spec_extend_merges_captures_and_interventions() {
        let mut a = HookSpec::new();
        a.capture(HookPoint::Embed)
            .intervene(HookPoint::AttnScores(1), Intervention::Zero);
        let mut b = HookSpec::new();
        b.capture(HookPoint::Embed)
            .capture(HookPoint::ResidPost(2))
            .intervene(HookPoint::AttnScores(1), Intervention::Scale(0.5));

        a.extend(&b);

        // `Embed` is requested by both specs but captured once.
        assert_eq!(a.num_captures(), 2);
        assert_eq!(a.num_interventions(), 2);
        let at_1: Vec<_> = a.interventions_at(&HookPoint::AttnScores(1)).collect();
        assert_eq!(at_1.len(), 2);
        assert!(matches!(at_1[0], Intervention::Zero));
        assert!(matches!(at_1[1], Intervention::Scale(f) if (*f - 0.5).abs() < f64::EPSILON));
    }

    #[test]
    fn hook_cache_store_get_require() {
        use candle_core::{DType, Device};

        let logits = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap();
        let mut cache = HookCache::new(logits);
        assert_eq!(cache.num_captures(), 0);
        assert!(cache.require(&HookPoint::Embed).is_err());

        let activation = Tensor::zeros((1, 3), DType::F32, &Device::Cpu).unwrap();
        cache.store(HookPoint::Embed, activation);

        assert_eq!(cache.num_captures(), 1);
        assert_eq!(cache.require(&HookPoint::Embed).unwrap().dims(), &[1, 3]);
        assert!(cache.get(&HookPoint::FinalNorm).is_none());
        assert_eq!(cache.output().dims(), &[1, 2]);
    }
}