use std::collections::{HashMap, HashSet};
use std::fmt;
use std::str::FromStr;
use candle_core::Tensor;
use crate::error::{MIError, Result};
use crate::interp::intervention::{StateKnockoutSpec, StateSteeringSpec};
/// A named location in the forward pass where activations can be captured
/// (`HookSpec::capture`) or modified (`HookSpec::intervene`).
///
/// Every built-in variant has one canonical string form. `Display` produces it,
/// and `FromStr` / `From<&str>` accept it. Per-block variants carry the block
/// index and render as `blocks.{layer}.{suffix}`. Any string that is not a
/// recognised name is kept verbatim as [`HookPoint::Custom`].
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum HookPoint {
    /// Canonical name: `hook_embed`.
    Embed,
    /// Canonical name: `blocks.{i}.hook_resid_pre`.
    ResidPre(usize),
    /// Canonical name: `blocks.{i}.attn.hook_q`.
    AttnQ(usize),
    /// Canonical name: `blocks.{i}.attn.hook_k`.
    AttnK(usize),
    /// Canonical name: `blocks.{i}.attn.hook_v`.
    AttnV(usize),
    /// Canonical name: `blocks.{i}.attn.hook_scores`.
    AttnScores(usize),
    /// Canonical name: `blocks.{i}.attn.hook_pattern`.
    AttnPattern(usize),
    /// Canonical name: `blocks.{i}.hook_attn_out`.
    AttnOut(usize),
    /// Canonical name: `blocks.{i}.hook_resid_mid`.
    ResidMid(usize),
    /// Canonical name: `blocks.{i}.mlp.hook_pre`.
    MlpPre(usize),
    /// Canonical name: `blocks.{i}.mlp.hook_post`.
    MlpPost(usize),
    /// Canonical name: `blocks.{i}.hook_mlp_out`.
    MlpOut(usize),
    /// Canonical name: `blocks.{i}.hook_resid_post`.
    ResidPost(usize),
    /// Canonical name: `hook_final_norm`.
    FinalNorm,
    /// Canonical name: `blocks.{i}.rwkv.hook_state`.
    RwkvState(usize),
    /// Canonical name: `blocks.{i}.rwkv.hook_decay`.
    RwkvDecay(usize),
    /// Canonical name: `blocks.{i}.rwkv.hook_effective_attn`.
    RwkvEffectiveAttn(usize),
    /// Any other hook name, stored and displayed exactly as given.
    Custom(String),
}

impl fmt::Display for HookPoint {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Model-level hooks and custom names are written as-is. Every other
        // variant is a per-block hook sharing the `blocks.{layer}.` prefix.
        let (layer, suffix) = match self {
            Self::Embed => return f.write_str("hook_embed"),
            Self::FinalNorm => return f.write_str("hook_final_norm"),
            Self::Custom(name) => return f.write_str(name),
            Self::ResidPre(layer) => (layer, "hook_resid_pre"),
            Self::AttnQ(layer) => (layer, "attn.hook_q"),
            Self::AttnK(layer) => (layer, "attn.hook_k"),
            Self::AttnV(layer) => (layer, "attn.hook_v"),
            Self::AttnScores(layer) => (layer, "attn.hook_scores"),
            Self::AttnPattern(layer) => (layer, "attn.hook_pattern"),
            Self::AttnOut(layer) => (layer, "hook_attn_out"),
            Self::ResidMid(layer) => (layer, "hook_resid_mid"),
            Self::MlpPre(layer) => (layer, "mlp.hook_pre"),
            Self::MlpPost(layer) => (layer, "mlp.hook_post"),
            Self::MlpOut(layer) => (layer, "hook_mlp_out"),
            Self::ResidPost(layer) => (layer, "hook_resid_post"),
            Self::RwkvState(layer) => (layer, "rwkv.hook_state"),
            Self::RwkvDecay(layer) => (layer, "rwkv.hook_decay"),
            Self::RwkvEffectiveAttn(layer) => (layer, "rwkv.hook_effective_attn"),
        };
        write!(f, "blocks.{layer}.{suffix}")
    }
}

impl FromStr for HookPoint {
    /// Parsing never fails: unrecognised names become [`HookPoint::Custom`].
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        Ok(Self::from(s))
    }
}

impl From<&str> for HookPoint {
    /// Infallible conversion from a hook name. See [`HookPoint`] for the grammar.
    fn from(s: &str) -> Self {
        parse_hook_string(s)
    }
}

/// Maps a hook name to its [`HookPoint`], falling back to [`HookPoint::Custom`].
fn parse_hook_string(s: &str) -> HookPoint {
    /// Recognises `blocks.{layer}.{suffix}` for the known per-block suffixes.
    /// Returns `None` for anything else, including a non-numeric layer.
    fn per_block(s: &str) -> Option<HookPoint> {
        let (layer_str, suffix) = s.strip_prefix("blocks.")?.split_once('.')?;
        let layer = layer_str.parse::<usize>().ok()?;
        // Tuple-variant constructors are plain `fn(usize) -> HookPoint`, so
        // the suffix only has to select the constructor.
        let variant: fn(usize) -> HookPoint = match suffix {
            "hook_resid_pre" => HookPoint::ResidPre,
            "attn.hook_q" => HookPoint::AttnQ,
            "attn.hook_k" => HookPoint::AttnK,
            "attn.hook_v" => HookPoint::AttnV,
            "attn.hook_scores" => HookPoint::AttnScores,
            "attn.hook_pattern" => HookPoint::AttnPattern,
            "hook_attn_out" => HookPoint::AttnOut,
            "hook_resid_mid" => HookPoint::ResidMid,
            "mlp.hook_pre" => HookPoint::MlpPre,
            "mlp.hook_post" => HookPoint::MlpPost,
            "hook_mlp_out" => HookPoint::MlpOut,
            "hook_resid_post" => HookPoint::ResidPost,
            "rwkv.hook_state" => HookPoint::RwkvState,
            "rwkv.hook_decay" => HookPoint::RwkvDecay,
            "rwkv.hook_effective_attn" => HookPoint::RwkvEffectiveAttn,
            _ => return None,
        };
        Some(variant(layer))
    }

    match s {
        "hook_embed" => HookPoint::Embed,
        "hook_final_norm" => HookPoint::FinalNorm,
        _ => per_block(s).unwrap_or_else(|| HookPoint::Custom(s.to_string())),
    }
}
/// An edit applied to the activation at a hook point.
///
/// In this file the variants are interpreted by `apply_intervention`, as
/// described on each variant below.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum Intervention {
/// Use this tensor in place of the activation.
Replace(Tensor),
/// Add this tensor to the activation, with broadcasting.
Add(Tensor),
/// Add this mask to the activation, with broadcasting.
/// NOTE(review): presumably 0 at kept positions and a large negative / `-inf`
/// value at knocked-out ones; confirm against the code that builds masks.
Knockout(Tensor),
/// Multiply the activation by this scalar.
Scale(f64),
/// Replace the activation with `zeros_like` of itself.
Zero,
}
/// Applies one `intervention` to the activation `tensor` and returns the new
/// activation.
///
/// Any tensor carried by the intervention (`Replace`, `Add`, `Knockout`) is
/// first cast to `tensor`'s dtype. For example, an F32 knockout mask can then
/// be applied to a BF16 activation instead of failing on a dtype mismatch.
/// Shapes are not checked here. `Add` and `Knockout` rely on candle
/// broadcasting. `Replace` returns the replacement as given, apart from the
/// dtype cast.
///
/// # Errors
///
/// Returns an error if candle fails during the dtype cast, the broadcast
/// addition, the scaling or the zero allocation (e.g. incompatible shapes or
/// devices).
#[cfg(any(feature = "transformer", feature = "rwkv"))]
pub(crate) fn apply_intervention(tensor: &Tensor, intervention: &Intervention) -> Result<Tensor> {
    let target = tensor.dtype();
    // Bring an intervention operand into the activation's dtype. When the
    // dtypes already agree, no conversion is performed.
    let in_target_dtype = |operand: &Tensor| -> Result<Tensor> {
        if operand.dtype() == target {
            Ok(operand.clone())
        } else {
            Ok(operand.to_dtype(target)?)
        }
    };
    match intervention {
        // Cast so downstream ops keep seeing the activation's dtype.
        Intervention::Replace(replacement) => in_target_dtype(replacement),
        Intervention::Add(delta) => Ok(tensor.broadcast_add(&in_target_dtype(delta)?)?),
        // Masks get the same dtype alignment as `Add`. A mask built in another
        // dtype would otherwise make `broadcast_add` fail.
        Intervention::Knockout(mask) => Ok(tensor.broadcast_add(&in_target_dtype(mask)?)?),
        Intervention::Scale(factor) => Ok((tensor * *factor)?),
        Intervention::Zero => Ok(tensor.zeros_like()?),
    }
}
/// Describes what a hooked forward pass should observe and modify.
///
/// It holds the hook points to capture, the interventions to apply, and
/// optional state-level knockout / steering specs.
#[derive(Debug, Clone, Default)]
pub struct HookSpec {
    /// Hook points whose activations should be recorded.
    captures: HashSet<HookPoint>,
    /// Interventions in insertion order. Several may target the same hook.
    interventions: Vec<(HookPoint, Intervention)>,
    /// Optional state knockout spec. Setting it again replaces the previous one.
    state_knockout: Option<StateKnockoutSpec>,
    /// Optional state steering spec. Setting it again replaces the previous one.
    state_steering: Option<StateSteeringSpec>,
}

impl HookSpec {
    /// Creates an empty spec with no captures, interventions or state specs.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Requests that the activation at `hook` be captured.
    ///
    /// Accepts a [`HookPoint`] or anything convertible into one, such as a hook
    /// name string. Capturing the same hook twice has no extra effect.
    pub fn capture<H: Into<HookPoint>>(&mut self, hook: H) -> &mut Self {
        self.captures.insert(hook.into());
        self
    }

    /// Registers `intervention` at `hook`.
    ///
    /// Interventions accumulate: registering another one at the same hook keeps
    /// the earlier ones.
    pub fn intervene<H: Into<HookPoint>>(
        &mut self,
        hook: H,
        intervention: Intervention,
    ) -> &mut Self {
        self.interventions.push((hook.into(), intervention));
        self
    }

    /// Returns `true` if `hook` was requested for capture.
    #[must_use]
    pub fn is_captured(&self, hook: &HookPoint) -> bool {
        self.captures.contains(hook)
    }

    /// Returns `true` if the spec requests no captures, no interventions and
    /// no state-level knockout or steering.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.captures.is_empty()
            && self.interventions.is_empty()
            && self.state_knockout.is_none()
            && self.state_steering.is_none()
    }

    /// Number of distinct hook points requested for capture.
    #[must_use]
    pub fn num_captures(&self) -> usize {
        self.captures.len()
    }

    /// Total number of registered interventions, counting repeats at one hook.
    #[must_use]
    pub const fn num_interventions(&self) -> usize {
        self.interventions.len()
    }

    /// Iterates over the interventions registered at `hook`, in insertion order.
    pub fn interventions_at(&self, hook: &HookPoint) -> impl Iterator<Item = &Intervention> {
        self.interventions
            .iter()
            .filter(move |(h, _)| h == hook)
            .map(|(_, intervention)| intervention)
    }

    /// Returns `true` if at least one intervention targets `hook`.
    #[must_use]
    pub fn has_intervention_at(&self, hook: &HookPoint) -> bool {
        self.interventions.iter().any(|(h, _)| h == hook)
    }

    /// Sets the state knockout spec, replacing any previous one.
    pub fn set_state_knockout(&mut self, spec: StateKnockoutSpec) -> &mut Self {
        self.state_knockout = Some(spec);
        self
    }

    /// Sets the state steering spec, replacing any previous one.
    pub fn set_state_steering(&mut self, spec: StateSteeringSpec) -> &mut Self {
        self.state_steering = Some(spec);
        self
    }

    /// The state knockout spec, if one was set.
    #[must_use]
    pub const fn state_knockout(&self) -> Option<&StateKnockoutSpec> {
        self.state_knockout.as_ref()
    }

    /// The state steering spec, if one was set.
    #[must_use]
    pub const fn state_steering(&self) -> Option<&StateSteeringSpec> {
        self.state_steering.as_ref()
    }

    /// Merges `other` into `self`.
    ///
    /// Captures are unioned. `other`'s interventions are appended after this
    /// spec's own. A state knockout or steering spec set on `other` replaces
    /// the one on `self`, following the same "last set wins" rule as
    /// [`set_state_knockout`](Self::set_state_knockout) and
    /// [`set_state_steering`](Self::set_state_steering). Where `other` has
    /// none, `self`'s is kept.
    pub fn extend(&mut self, other: &Self) -> &mut Self {
        self.captures.extend(other.captures.iter().cloned());
        self.interventions
            .extend(other.interventions.iter().cloned());
        // State-level specs are part of the spec too. Skipping them would
        // silently lose `other`'s knockout/steering after a merge.
        if let Some(knockout) = &other.state_knockout {
            self.state_knockout = Some(knockout.clone());
        }
        if let Some(steering) = &other.state_steering {
            self.state_steering = Some(steering.clone());
        }
        self
    }
}
/// Result of a hooked forward pass.
///
/// Holds an output tensor plus the activations stored for individual hook
/// points.
#[derive(Debug)]
pub struct HookCache {
    /// Output tensor, as passed to `new` or `set_output`.
    output: Tensor,
    /// Stored activations, keyed by hook point.
    captures: HashMap<HookPoint, Tensor>,
}

impl HookCache {
    /// Wraps `output` in a cache that has no captured activations yet.
    #[must_use]
    pub fn new(output: Tensor) -> Self {
        let captures = HashMap::new();
        Self { output, captures }
    }

    /// Borrows the output tensor.
    #[must_use]
    pub const fn output(&self) -> &Tensor {
        &self.output
    }

    /// Consumes the cache and returns only the output, discarding all captures.
    #[must_use]
    pub fn into_output(self) -> Tensor {
        let Self { output, .. } = self;
        output
    }

    /// Looks up the activation stored for `hook`, if any.
    #[must_use]
    pub fn get(&self, hook: &HookPoint) -> Option<&Tensor> {
        self.captures.get(hook)
    }

    /// Like [`get`](Self::get), but treats a missing capture as an error.
    ///
    /// # Errors
    ///
    /// Returns `MIError::Hook` if nothing was stored for `hook`.
    pub fn require(&self, hook: &HookPoint) -> Result<&Tensor> {
        match self.captures.get(hook) {
            Some(activation) => Ok(activation),
            None => Err(MIError::Hook(format!("hook point `{hook}` was not captured"))),
        }
    }

    /// Stores `tensor` as the activation for `hook`, replacing any earlier
    /// entry for the same hook.
    pub fn store(&mut self, hook: HookPoint, tensor: Tensor) {
        let captures = &mut self.captures;
        captures.insert(hook, tensor);
    }

    /// Replaces the output tensor.
    pub fn set_output(&mut self, output: Tensor) {
        let slot = &mut self.output;
        *slot = output;
    }

    /// Number of distinct hook points with a stored activation.
    #[must_use]
    pub fn num_captures(&self) -> usize {
        let captures = &self.captures;
        captures.len()
    }
}
// Unit tests for hook-point naming and `HookSpec` bookkeeping. None of them
// builds a tensor or runs a model.
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
/// Every built-in variant must `Display` as its canonical name and convert
/// back to the same variant via both `FromStr` and `From<&str>`.
#[test]
fn hook_point_display_roundtrip() {
let cases: Vec<(HookPoint, &str)> = vec![
(HookPoint::Embed, "hook_embed"),
(HookPoint::FinalNorm, "hook_final_norm"),
(HookPoint::ResidPre(0), "blocks.0.hook_resid_pre"),
(HookPoint::AttnQ(3), "blocks.3.attn.hook_q"),
(HookPoint::AttnK(3), "blocks.3.attn.hook_k"),
(HookPoint::AttnV(3), "blocks.3.attn.hook_v"),
(HookPoint::AttnScores(7), "blocks.7.attn.hook_scores"),
(HookPoint::AttnPattern(5), "blocks.5.attn.hook_pattern"),
(HookPoint::AttnOut(2), "blocks.2.hook_attn_out"),
(HookPoint::ResidMid(11), "blocks.11.hook_resid_mid"),
(HookPoint::MlpPre(1), "blocks.1.mlp.hook_pre"),
(HookPoint::MlpPost(1), "blocks.1.mlp.hook_post"),
(HookPoint::MlpOut(4), "blocks.4.hook_mlp_out"),
(HookPoint::ResidPost(9), "blocks.9.hook_resid_post"),
(HookPoint::RwkvState(6), "blocks.6.rwkv.hook_state"),
(HookPoint::RwkvDecay(6), "blocks.6.rwkv.hook_decay"),
(
HookPoint::RwkvEffectiveAttn(6),
"blocks.6.rwkv.hook_effective_attn",
),
];
for (hook, expected_str) in cases {
assert_eq!(
hook.to_string(),
expected_str,
"Display failed for {hook:?}"
);
let parsed: HookPoint = expected_str.parse().unwrap();
assert_eq!(parsed, hook, "FromStr failed for {expected_str:?}");
let from_str: HookPoint = HookPoint::from(expected_str);
assert_eq!(from_str, hook, "From<&str> failed for {expected_str:?}");
}
}
/// Names outside the built-in grammar are kept verbatim as `Custom`.
#[test]
fn unknown_string_becomes_custom() {
let hook: HookPoint = "some.unknown.hook".parse().unwrap();
assert_eq!(hook, HookPoint::Custom("some.unknown.hook".to_string()));
}
/// `capture` accepts both typed hook points and hook-name strings. A string
/// capture is queryable through the equivalent typed variant.
#[test]
fn hook_spec_capture_and_query() {
let mut spec = HookSpec::new();
assert!(spec.is_empty());
spec.capture(HookPoint::AttnPattern(5));
// Parsed to `HookPoint::ResidPost(3)`, as the query below checks.
spec.capture("blocks.3.hook_resid_post");
assert!(!spec.is_empty());
assert_eq!(spec.num_captures(), 2);
assert!(spec.is_captured(&HookPoint::AttnPattern(5)));
assert!(spec.is_captured(&HookPoint::ResidPost(3)));
assert!(!spec.is_captured(&HookPoint::Embed));
}
/// Interventions accumulate: two at the same hook are both kept and both
/// reported by `interventions_at`.
#[test]
fn hook_spec_intervention_query() {
let mut spec = HookSpec::new();
spec.intervene(HookPoint::AttnScores(5), Intervention::Zero);
spec.intervene(HookPoint::AttnScores(5), Intervention::Scale(2.0));
spec.intervene(HookPoint::ResidPost(10), Intervention::Zero);
assert_eq!(spec.num_interventions(), 3);
assert!(spec.has_intervention_at(&HookPoint::AttnScores(5)));
assert!(!spec.has_intervention_at(&HookPoint::Embed));
let at_5: Vec<_> = spec.interventions_at(&HookPoint::AttnScores(5)).collect();
assert_eq!(at_5.len(), 2);
}
}