Skip to main content

candle_mi/
hooks.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Hook system for activation capture and intervention.
4//!
5//! Provides [`HookPoint`] (named locations in a forward pass),
6//! [`HookSpec`] (what to capture and where to intervene), and
7//! [`HookCache`] (captured tensors from a forward pass).
8//!
9//! See `design/hook-system.md` for the design rationale.
10
11use std::collections::{HashMap, HashSet};
12use std::fmt;
13use std::str::FromStr;
14
15use candle_core::Tensor;
16
17use crate::error::{MIError, Result};
18use crate::interp::intervention::{StateKnockoutSpec, StateSteeringSpec};
19
20// ---------------------------------------------------------------------------
21// HookPoint
22// ---------------------------------------------------------------------------
23
/// Named location in a forward pass where activations can be captured
/// or interventions applied.
///
/// Mirrors the `TransformerLens` hook point naming convention via
/// [`Display`](std::fmt::Display) and [`FromStr`].
///
/// # String conversion
///
/// ```
/// use candle_mi::HookPoint;
///
/// let hook = HookPoint::AttnPattern(5);
/// assert_eq!(hook.to_string(), "blocks.5.attn.hook_pattern");
///
/// let parsed: HookPoint = "blocks.5.attn.hook_pattern".parse().unwrap();
/// assert_eq!(parsed, hook);
/// ```
///
/// Every standard (non-`Custom`) variant round-trips: formatting it with
/// `Display` and parsing the result yields an equal value.
///
/// Unknown strings parse as [`HookPoint::Custom`], providing an escape
/// hatch for backend-specific hook points.
///
/// Equality is structural. For example,
/// `HookPoint::Custom("blocks.5.attn.hook_pattern".to_string())` is *not* equal
/// to `HookPoint::AttnPattern(5)`, even though both display the same string.
/// Parsing a string always produces the standard variant when one matches.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum HookPoint {
    // -- Embedding --
    /// After token embedding (`hook_embed`).
    Embed,

    // -- Per-layer: transformer --
    // The `usize` payload of every per-layer variant is the layer index `i`.
    /// Residual stream before layer `i` (`blocks.{i}.hook_resid_pre`).
    ResidPre(usize),
    /// Query vectors in layer `i` (`blocks.{i}.attn.hook_q`).
    AttnQ(usize),
    /// Key vectors in layer `i` (`blocks.{i}.attn.hook_k`).
    AttnK(usize),
    /// Value vectors in layer `i` (`blocks.{i}.attn.hook_v`).
    AttnV(usize),
    /// Pre-softmax attention scores in layer `i` (`blocks.{i}.attn.hook_scores`).
    AttnScores(usize),
    /// Post-softmax attention pattern in layer `i` (`blocks.{i}.attn.hook_pattern`).
    AttnPattern(usize),
    /// Attention output in layer `i` (`blocks.{i}.hook_attn_out`).
    AttnOut(usize),
    /// Residual stream between attention and MLP in layer `i`
    /// (`blocks.{i}.hook_resid_mid`).
    ResidMid(usize),
    /// MLP pre-activation in layer `i` (`blocks.{i}.mlp.hook_pre`).
    MlpPre(usize),
    /// MLP post-activation in layer `i` (`blocks.{i}.mlp.hook_post`).
    MlpPost(usize),
    /// MLP output in layer `i` (`blocks.{i}.hook_mlp_out`).
    MlpOut(usize),
    /// Residual stream after full layer `i` (`blocks.{i}.hook_resid_post`).
    ResidPost(usize),

    // -- Final --
    /// After final layer norm (`hook_final_norm`).
    FinalNorm,

    // -- RWKV-specific --
    /// RWKV recurrent state at layer `i` (`blocks.{i}.rwkv.hook_state`).
    RwkvState(usize),
    /// RWKV decay vector at layer `i` (`blocks.{i}.rwkv.hook_decay`).
    RwkvDecay(usize),
    /// RWKV effective attention at layer `i` (`blocks.{i}.rwkv.hook_effective_attn`).
    ///
    /// Shape: `[batch, heads, seq_query, seq_source]`.
    /// Derived from the WKV recurrence by computing how much each
    /// source position contributes to each query position's output.
    /// Normalised via `ReLU` + L1.
    RwkvEffectiveAttn(usize),

    // -- Escape hatch --
    /// Backend-specific hook point not covered by the standard enum.
    ///
    /// Displayed verbatim. Produced by parsing any string that does not match
    /// a standard hook-point name.
    Custom(String),
}
99
100impl fmt::Display for HookPoint {
101    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102        match self {
103            Self::Embed => write!(f, "hook_embed"),
104            Self::ResidPre(i) => write!(f, "blocks.{i}.hook_resid_pre"),
105            Self::AttnQ(i) => write!(f, "blocks.{i}.attn.hook_q"),
106            Self::AttnK(i) => write!(f, "blocks.{i}.attn.hook_k"),
107            Self::AttnV(i) => write!(f, "blocks.{i}.attn.hook_v"),
108            Self::AttnScores(i) => write!(f, "blocks.{i}.attn.hook_scores"),
109            Self::AttnPattern(i) => write!(f, "blocks.{i}.attn.hook_pattern"),
110            Self::AttnOut(i) => write!(f, "blocks.{i}.hook_attn_out"),
111            Self::ResidMid(i) => write!(f, "blocks.{i}.hook_resid_mid"),
112            Self::MlpPre(i) => write!(f, "blocks.{i}.mlp.hook_pre"),
113            Self::MlpPost(i) => write!(f, "blocks.{i}.mlp.hook_post"),
114            Self::MlpOut(i) => write!(f, "blocks.{i}.hook_mlp_out"),
115            Self::ResidPost(i) => write!(f, "blocks.{i}.hook_resid_post"),
116            Self::FinalNorm => write!(f, "hook_final_norm"),
117            Self::RwkvState(i) => write!(f, "blocks.{i}.rwkv.hook_state"),
118            Self::RwkvDecay(i) => write!(f, "blocks.{i}.rwkv.hook_decay"),
119            Self::RwkvEffectiveAttn(i) => write!(f, "blocks.{i}.rwkv.hook_effective_attn"),
120            Self::Custom(s) => write!(f, "{s}"),
121        }
122    }
123}
124
125/// Parse a `TransformerLens`-style string into a [`HookPoint`].
126///
127/// Unknown strings produce [`HookPoint::Custom`] rather than an error.
128impl FromStr for HookPoint {
129    type Err = std::convert::Infallible;
130
131    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
132        Ok(parse_hook_string(s))
133    }
134}
135
136/// Allow `hooks.capture("blocks.5.attn.hook_pattern")` via `Into<HookPoint>`.
137impl From<&str> for HookPoint {
138    fn from(s: &str) -> Self {
139        parse_hook_string(s)
140    }
141}
142
143/// Parse a hook string, falling back to [`HookPoint::Custom`] for unknown patterns.
144fn parse_hook_string(s: &str) -> HookPoint {
145    match s {
146        "hook_embed" => return HookPoint::Embed,
147        "hook_final_norm" => return HookPoint::FinalNorm,
148        _ => {}
149    }
150
151    // Try "blocks.{layer}.{suffix}" pattern.
152    if let Some(rest) = s.strip_prefix("blocks.") {
153        if let Some((layer_str, suffix)) = rest.split_once('.') {
154            if let Ok(layer) = layer_str.parse::<usize>() {
155                return match suffix {
156                    "hook_resid_pre" => HookPoint::ResidPre(layer),
157                    "attn.hook_q" => HookPoint::AttnQ(layer),
158                    "attn.hook_k" => HookPoint::AttnK(layer),
159                    "attn.hook_v" => HookPoint::AttnV(layer),
160                    "attn.hook_scores" => HookPoint::AttnScores(layer),
161                    "attn.hook_pattern" => HookPoint::AttnPattern(layer),
162                    "hook_attn_out" => HookPoint::AttnOut(layer),
163                    "hook_resid_mid" => HookPoint::ResidMid(layer),
164                    "mlp.hook_pre" => HookPoint::MlpPre(layer),
165                    "mlp.hook_post" => HookPoint::MlpPost(layer),
166                    "hook_mlp_out" => HookPoint::MlpOut(layer),
167                    "hook_resid_post" => HookPoint::ResidPost(layer),
168                    "rwkv.hook_state" => HookPoint::RwkvState(layer),
169                    "rwkv.hook_decay" => HookPoint::RwkvDecay(layer),
170                    "rwkv.hook_effective_attn" => HookPoint::RwkvEffectiveAttn(layer),
171                    _ => HookPoint::Custom(s.to_string()),
172                };
173            }
174        }
175    }
176
177    HookPoint::Custom(s.to_string())
178}
179
180// ---------------------------------------------------------------------------
181// Intervention
182// ---------------------------------------------------------------------------
183
184/// An intervention to apply at a hook point during the forward pass.
185///
186/// Interventions modify activations in-place as they flow through the model.
187/// They are specified as part of a [`HookSpec`] and applied by the backend
188/// at the corresponding [`HookPoint`].
189#[non_exhaustive]
190#[derive(Debug, Clone)]
191pub enum Intervention {
192    /// Replace the tensor entirely with a provided value.
193    Replace(Tensor),
194
195    /// Add a vector to the activation (e.g., residual stream steering).
196    Add(Tensor),
197
198    /// Apply a pre-softmax knockout mask.
199    ///
200    /// The mask tensor contains `0.0` for positions to keep and
201    /// `-inf` for positions to knock out. Added to attention scores.
202    Knockout(Tensor),
203
204    /// Scale attention weights by a constant factor.
205    Scale(f64),
206
207    /// Zero the tensor at this hook point.
208    Zero,
209}
210
211// ---------------------------------------------------------------------------
212// Intervention application
213// ---------------------------------------------------------------------------
214
/// Apply a single [`Intervention`] to a tensor.
///
/// Used by backend implementations at each hook point that supports
/// interventions (e.g., Embed, `AttnScores`, `AttnPattern`).
///
/// If a tensor carried by the intervention (`Replace`, `Add`, `Knockout`) has a
/// different dtype from the activation, it is first converted to the
/// activation's dtype. This allows, for example, an F32 steering vector or
/// knockout mask to be used in a BF16 forward pass. It also supports CLT
/// injection, where steering vectors are accumulated in F32 for numerical
/// stability.
///
/// # Shapes
/// - `tensor`: any shape — the activation at the hook point.
/// - returns: same shape as `tensor` for `Scale` and `Zero`. For `Add` and
///   `Knockout`, it is the same shape when the operand broadcasts to `tensor`.
///   For `Replace`, it is the replacement's shape.
///
/// # Errors
///
/// Returns [`MIError::Model`] if the underlying tensor operation fails
/// (including the dtype conversion or an incompatible broadcast).
#[cfg(any(feature = "transformer", feature = "rwkv"))]
pub(crate) fn apply_intervention(tensor: &Tensor, intervention: &Intervention) -> Result<Tensor> {
    /// Return `operand` in `like`'s dtype, converting only when they differ.
    fn match_dtype(operand: &Tensor, like: &Tensor) -> Result<Tensor> {
        if operand.dtype() == like.dtype() {
            Ok(operand.clone())
        } else {
            Ok(operand.to_dtype(like.dtype())?)
        }
    }

    match intervention {
        Intervention::Replace(replacement) => match_dtype(replacement, tensor),
        Intervention::Add(delta) => Ok(tensor.broadcast_add(&match_dtype(delta, tensor)?)?),
        // Knockout masks are often built in F32 (e.g. with `-inf` entries); convert
        // them like `Add` deltas so they work against lower-precision scores.
        Intervention::Knockout(mask) => Ok(tensor.broadcast_add(&match_dtype(mask, tensor)?)?),
        Intervention::Scale(factor) => Ok((tensor * *factor)?),
        Intervention::Zero => Ok(tensor.zeros_like()?),
    }
}
247
248// ---------------------------------------------------------------------------
249// HookSpec
250// ---------------------------------------------------------------------------
251
/// Declares which activations to capture and which interventions to apply.
///
/// Passed to [`MIBackend::forward`](crate::MIBackend::forward). When empty,
/// the forward pass has zero overhead (no clones, no extra allocations).
///
/// Captures form a set, so requesting the same hook point twice has no extra
/// effect. Interventions are kept in registration order and may repeat a hook
/// point. A spec can additionally carry at most one RWKV state knockout
/// specification and at most one RWKV state steering specification.
///
/// # Example
///
/// ```
/// use candle_mi::{HookPoint, HookSpec};
///
/// let mut hooks = HookSpec::new();
/// hooks.capture(HookPoint::AttnPattern(5))
///      .capture("blocks.5.hook_resid_post");
/// ```
#[derive(Debug, Clone, Default)]
pub struct HookSpec {
    /// Hook points to capture during the forward pass (deduplicated).
    captures: HashSet<HookPoint>,
    /// Interventions to apply, stored as (`hook_point`, intervention) pairs
    /// in registration order.
    interventions: Vec<(HookPoint, Intervention)>,
    /// RWKV state knockout specification (skip kv write at specified positions).
    state_knockout: Option<StateKnockoutSpec>,
    /// RWKV state steering specification (scale kv write at specified positions).
    state_steering: Option<StateSteeringSpec>,
}
277
278impl HookSpec {
279    /// Create an empty hook specification (no captures, no interventions).
280    #[must_use]
281    pub fn new() -> Self {
282        Self::default()
283    }
284
285    /// Request capture of the activation at the given hook point.
286    pub fn capture<H: Into<HookPoint>>(&mut self, hook: H) -> &mut Self {
287        self.captures.insert(hook.into());
288        self
289    }
290
291    /// Register an intervention at the given hook point.
292    pub fn intervene<H: Into<HookPoint>>(
293        &mut self,
294        hook: H,
295        intervention: Intervention,
296    ) -> &mut Self {
297        self.interventions.push((hook.into(), intervention));
298        self
299    }
300
301    /// Check whether a specific hook point should be captured.
302    #[must_use]
303    pub fn is_captured(&self, hook: &HookPoint) -> bool {
304        self.captures.contains(hook)
305    }
306
307    /// Check whether this spec has no captures, no interventions, and no
308    /// state specs (knockout/steering).
309    #[must_use]
310    pub fn is_empty(&self) -> bool {
311        self.captures.is_empty()
312            && self.interventions.is_empty()
313            && self.state_knockout.is_none()
314            && self.state_steering.is_none()
315    }
316
317    /// Number of requested captures.
318    #[must_use]
319    pub fn num_captures(&self) -> usize {
320        self.captures.len()
321    }
322
323    /// Number of registered interventions.
324    #[must_use]
325    pub const fn num_interventions(&self) -> usize {
326        self.interventions.len()
327    }
328
329    /// Iterate over interventions registered at a specific hook point.
330    pub fn interventions_at(&self, hook: &HookPoint) -> impl Iterator<Item = &Intervention> {
331        self.interventions
332            .iter()
333            .filter(move |(h, _)| h == hook)
334            .map(|(_, intervention)| intervention)
335    }
336
337    /// Check whether any intervention targets the given hook point.
338    #[must_use]
339    pub fn has_intervention_at(&self, hook: &HookPoint) -> bool {
340        self.interventions.iter().any(|(h, _)| h == hook)
341    }
342
343    /// Set an RWKV state knockout specification.
344    ///
345    /// At specified token positions, the WKV recurrence skips the kv write,
346    /// effectively making those tokens invisible to all future positions.
347    pub fn set_state_knockout(&mut self, spec: StateKnockoutSpec) -> &mut Self {
348        self.state_knockout = Some(spec);
349        self
350    }
351
352    /// Set an RWKV state steering specification.
353    ///
354    /// At specified token positions, the WKV recurrence scales the kv write
355    /// by the given factor, amplifying or dampening the token's contribution.
356    pub fn set_state_steering(&mut self, spec: StateSteeringSpec) -> &mut Self {
357        self.state_steering = Some(spec);
358        self
359    }
360
361    /// Get the state knockout specification, if any.
362    #[must_use]
363    pub const fn state_knockout(&self) -> Option<&StateKnockoutSpec> {
364        self.state_knockout.as_ref()
365    }
366
367    /// Get the state steering specification, if any.
368    #[must_use]
369    pub const fn state_steering(&self) -> Option<&StateSteeringSpec> {
370        self.state_steering.as_ref()
371    }
372
373    /// Merge all captures and interventions from another [`HookSpec`] into this one.
374    ///
375    /// Useful for combining multiple intervention sources (e.g., suppress +
376    /// inject in CLT steering).
377    pub fn extend(&mut self, other: &Self) -> &mut Self {
378        self.captures.extend(other.captures.iter().cloned());
379        self.interventions
380            .extend(other.interventions.iter().cloned());
381        self
382    }
383}
384
385// ---------------------------------------------------------------------------
386// HookCache
387// ---------------------------------------------------------------------------
388
/// Tensors captured during a forward pass, plus the output logits.
///
/// Returned by [`MIBackend::forward`](crate::MIBackend::forward). Use
/// [`get`](Self::get) to retrieve activations at specific hook points. Use
/// [`require`](Self::require) instead to get an error rather than `None` when
/// a hook point was not captured.
///
/// The cache holds at most one tensor per hook point. Storing again for the
/// same hook point replaces the earlier capture.
///
/// # Example
///
/// ```
/// use candle_mi::{HookCache, HookPoint};
/// use candle_core::{Device, Tensor};
///
/// let logits = Tensor::zeros((1, 10, 32000), candle_core::DType::F32, &Device::Cpu).unwrap();
/// let mut cache = HookCache::new(logits);
///
/// // Store a captured activation
/// let pattern = Tensor::zeros((1, 8, 10, 10), candle_core::DType::F32, &Device::Cpu).unwrap();
/// cache.store(HookPoint::AttnPattern(5), pattern);
///
/// // Retrieve captured activations
/// let output = cache.output();
/// let attn = cache.get(&HookPoint::AttnPattern(5)).unwrap();
/// ```
#[derive(Debug)]
pub struct HookCache {
    /// Output tensor from the forward pass (typically logits).
    output: Tensor,
    /// Captured activations keyed by hook point (one tensor per hook point).
    captures: HashMap<HookPoint, Tensor>,
}
418
419impl HookCache {
420    /// Create a new cache with the given output tensor and no captures.
421    #[must_use]
422    pub fn new(output: Tensor) -> Self {
423        Self {
424            output,
425            captures: HashMap::new(),
426        }
427    }
428
429    /// The output tensor from the forward pass.
430    #[must_use]
431    pub const fn output(&self) -> &Tensor {
432        &self.output
433    }
434
435    /// Consume the cache and return the output tensor.
436    #[must_use]
437    pub fn into_output(self) -> Tensor {
438        self.output
439    }
440
441    /// Retrieve a captured tensor by hook point.
442    #[must_use]
443    pub fn get(&self, hook: &HookPoint) -> Option<&Tensor> {
444        self.captures.get(hook)
445    }
446
447    /// Retrieve a captured tensor, returning an error if not found.
448    ///
449    /// # Errors
450    ///
451    /// Returns [`MIError::Hook`] if the hook point was not captured.
452    pub fn require(&self, hook: &HookPoint) -> Result<&Tensor> {
453        self.captures
454            .get(hook)
455            .ok_or_else(|| MIError::Hook(format!("hook point `{hook}` was not captured")))
456    }
457
458    /// Store a captured activation. Called by backend implementations.
459    pub fn store(&mut self, hook: HookPoint, tensor: Tensor) {
460        self.captures.insert(hook, tensor);
461    }
462
463    /// Replace the output tensor (e.g., after computing final logits).
464    ///
465    /// This allows the forward pass to collect captures into a cache
466    /// initialized with a placeholder, then set the real output at the end.
467    pub fn set_output(&mut self, output: Tensor) {
468        self.output = output;
469    }
470
471    /// Number of captured tensors (excludes the output).
472    #[must_use]
473    pub fn num_captures(&self) -> usize {
474        self.captures.len()
475    }
476}
477
478// ---------------------------------------------------------------------------
479// Tests
480// ---------------------------------------------------------------------------
481
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn hook_point_display_roundtrip() {
        let cases: Vec<(HookPoint, &str)> = vec![
            (HookPoint::Embed, "hook_embed"),
            (HookPoint::FinalNorm, "hook_final_norm"),
            (HookPoint::ResidPre(0), "blocks.0.hook_resid_pre"),
            (HookPoint::AttnQ(3), "blocks.3.attn.hook_q"),
            (HookPoint::AttnK(3), "blocks.3.attn.hook_k"),
            (HookPoint::AttnV(3), "blocks.3.attn.hook_v"),
            (HookPoint::AttnScores(7), "blocks.7.attn.hook_scores"),
            (HookPoint::AttnPattern(5), "blocks.5.attn.hook_pattern"),
            (HookPoint::AttnOut(2), "blocks.2.hook_attn_out"),
            (HookPoint::ResidMid(11), "blocks.11.hook_resid_mid"),
            (HookPoint::MlpPre(1), "blocks.1.mlp.hook_pre"),
            (HookPoint::MlpPost(1), "blocks.1.mlp.hook_post"),
            (HookPoint::MlpOut(4), "blocks.4.hook_mlp_out"),
            (HookPoint::ResidPost(9), "blocks.9.hook_resid_post"),
            (HookPoint::RwkvState(6), "blocks.6.rwkv.hook_state"),
            (HookPoint::RwkvDecay(6), "blocks.6.rwkv.hook_decay"),
            (
                HookPoint::RwkvEffectiveAttn(6),
                "blocks.6.rwkv.hook_effective_attn",
            ),
        ];

        for (hook, expected_str) in cases {
            // Display
            assert_eq!(
                hook.to_string(),
                expected_str,
                "Display failed for {hook:?}"
            );
            // FromStr roundtrip
            let parsed: HookPoint = expected_str.parse().unwrap();
            assert_eq!(parsed, hook, "FromStr failed for {expected_str:?}");
            // From<&str>
            let from_str: HookPoint = HookPoint::from(expected_str);
            assert_eq!(from_str, hook, "From<&str> failed for {expected_str:?}");
        }
    }

    #[test]
    fn unknown_string_becomes_custom() {
        let hook: HookPoint = "some.unknown.hook".parse().unwrap();
        assert_eq!(hook, HookPoint::Custom("some.unknown.hook".to_string()));
    }

    #[test]
    fn known_prefix_with_unknown_suffix_becomes_custom() {
        // The `blocks.{layer}.` prefix alone is not enough; the full string is kept.
        let hook = HookPoint::from("blocks.3.attn.hook_unknown");
        assert_eq!(
            hook,
            HookPoint::Custom("blocks.3.attn.hook_unknown".to_string())
        );
        // A non-numeric layer also falls back to `Custom`.
        let hook = HookPoint::from("blocks.x.hook_resid_pre");
        assert_eq!(hook, HookPoint::Custom("blocks.x.hook_resid_pre".to_string()));
    }

    #[test]
    fn custom_hook_displays_verbatim() {
        let hook = HookPoint::Custom("backend.special_hook".to_string());
        assert_eq!(hook.to_string(), "backend.special_hook");
        assert_eq!(HookPoint::from("backend.special_hook"), hook);
    }

    #[test]
    fn hook_spec_capture_and_query() {
        let mut spec = HookSpec::new();
        assert!(spec.is_empty());

        spec.capture(HookPoint::AttnPattern(5));
        spec.capture("blocks.3.hook_resid_post");

        assert!(!spec.is_empty());
        assert_eq!(spec.num_captures(), 2);
        assert!(spec.is_captured(&HookPoint::AttnPattern(5)));
        assert!(spec.is_captured(&HookPoint::ResidPost(3)));
        assert!(!spec.is_captured(&HookPoint::Embed));
    }

    #[test]
    fn hook_spec_intervention_query() {
        let mut spec = HookSpec::new();
        spec.intervene(HookPoint::AttnScores(5), Intervention::Zero);
        spec.intervene(HookPoint::AttnScores(5), Intervention::Scale(2.0));
        spec.intervene(HookPoint::ResidPost(10), Intervention::Zero);

        assert_eq!(spec.num_interventions(), 3);
        assert!(spec.has_intervention_at(&HookPoint::AttnScores(5)));
        assert!(!spec.has_intervention_at(&HookPoint::Embed));

        let at_5: Vec<_> = spec.interventions_at(&HookPoint::AttnScores(5)).collect();
        assert_eq!(at_5.len(), 2);
    }

    #[test]
    fn hook_spec_extend_merges_captures_and_interventions() {
        let mut base = HookSpec::new();
        base.capture(HookPoint::Embed);
        base.intervene(HookPoint::AttnScores(0), Intervention::Zero);

        let mut extra = HookSpec::new();
        extra.capture(HookPoint::Embed).capture(HookPoint::FinalNorm);
        extra.intervene(HookPoint::AttnScores(0), Intervention::Scale(0.5));

        base.extend(&extra);

        // Captures are unioned (Embed is not duplicated).
        assert_eq!(base.num_captures(), 2);
        assert!(base.is_captured(&HookPoint::FinalNorm));
        // Interventions are appended in order after the existing ones.
        assert_eq!(base.num_interventions(), 2);
        let mut at_0 = base.interventions_at(&HookPoint::AttnScores(0));
        assert!(matches!(at_0.next(), Some(Intervention::Zero)));
        assert!(matches!(
            at_0.next(),
            Some(Intervention::Scale(f)) if (*f - 0.5).abs() < f64::EPSILON
        ));
        assert!(at_0.next().is_none());
    }

    #[test]
    fn hook_cache_store_get_require() {
        let device = candle_core::Device::Cpu;
        let logits = Tensor::zeros((1, 4), candle_core::DType::F32, &device).unwrap();
        let mut cache = HookCache::new(logits);

        assert_eq!(cache.num_captures(), 0);
        assert!(cache.get(&HookPoint::Embed).is_none());
        assert!(matches!(
            cache.require(&HookPoint::Embed),
            Err(MIError::Hook(_))
        ));

        let embed = Tensor::ones((1, 2, 3), candle_core::DType::F32, &device).unwrap();
        cache.store(HookPoint::Embed, embed);

        assert_eq!(cache.num_captures(), 1);
        assert_eq!(cache.require(&HookPoint::Embed).unwrap().dims(), &[1, 2, 3]);
        assert_eq!(cache.output().dims(), &[1, 4]);
    }
}