use candle_core::{D, DType, Module, Tensor};
use candle_nn::{Linear, VarBuilder};
use crate::config::{QkvLayout, TransformerConfig};
use crate::error::Result;
use crate::hooks::{HookCache, HookPoint, HookSpec};
use super::rope::RopeCache;
/// Query/key/value projection weights, stored in one of two layouts.
enum QkvProj {
/// Three independent linear layers, each applied to the same input.
Separate {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
},
/// A single linear layer whose output is split along the last dimension
/// into a query slice, then a key slice, then a value slice.
Fused {
qkv_proj: Linear,
/// Width of the query slice, which starts at offset 0 of the fused output.
q_dim: usize,
/// Width of the key slice and of the value slice, which follow the query
/// slice in that order.
kv_dim: usize,
},
}
impl QkvProj {
    /// Projects `x` into `(query, key, value)` tensors.
    ///
    /// With the separate layout each tensor comes from its own linear layer.
    /// With the fused layout a single projection is computed and then cut
    /// along the last dimension into `q_dim`, `kv_dim` and `kv_dim` wide
    /// slices, in that order.
    ///
    /// # Errors
    ///
    /// Propagates any error raised by the linear layers or by `narrow`.
    fn forward(&self, x: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
        let projected = match self {
            Self::Separate {
                q_proj,
                k_proj,
                v_proj,
            } => (q_proj.forward(x)?, k_proj.forward(x)?, v_proj.forward(x)?),
            Self::Fused {
                qkv_proj,
                q_dim,
                kv_dim,
            } => {
                let (q_width, kv_width) = (*q_dim, *kv_dim);
                let fused = qkv_proj.forward(x)?;
                // Slice layout along the last dimension: [ q | k | v ].
                let query = fused.narrow(D::Minus1, 0, q_width)?;
                let key = fused.narrow(D::Minus1, q_width, kv_width)?;
                let value = fused.narrow(D::Minus1, q_width + kv_width, kv_width)?;
                (query, key, value)
            }
        };
        Ok(projected)
    }
}
/// Multi-head attention layer whose intermediate tensors can be captured or
/// modified through hooks.
pub struct Attention {
/// Query/key/value projection(s).
qkv: QkvProj,
/// Output projection applied to the merged per-head results.
o_proj: Linear,
/// Number of query heads.
num_attention_heads: usize,
/// Number of key/value heads; keys and values are repeated up to
/// `num_attention_heads` when the two counts differ.
num_kv_heads: usize,
/// Size of each head's feature dimension.
head_dim: usize,
/// Factor the raw query-key scores are multiplied by.
scale: f64,
/// When set to `cap`, scores are replaced by `tanh(scores / cap) * cap`
/// before the mask is added.
attn_logit_softcapping: Option<f64>,
}
impl Attention {
#[allow(clippy::needless_pass_by_value)] pub fn load(config: &TransformerConfig, vb: VarBuilder<'_>) -> Result<Self> {
let q_dim = config.num_attention_heads * config.head_dim;
let kv_dim = config.num_kv_heads * config.head_dim;
let qkv = match config.qkv_layout {
QkvLayout::Separate => {
let q_proj = if config.qkv_bias {
candle_nn::linear(config.hidden_size, q_dim, vb.pp("q_proj"))?
} else {
candle_nn::linear_no_bias(config.hidden_size, q_dim, vb.pp("q_proj"))?
};
let k_proj = if config.qkv_bias {
candle_nn::linear(config.hidden_size, kv_dim, vb.pp("k_proj"))?
} else {
candle_nn::linear_no_bias(config.hidden_size, kv_dim, vb.pp("k_proj"))?
};
let v_proj = if config.qkv_bias {
candle_nn::linear(config.hidden_size, kv_dim, vb.pp("v_proj"))?
} else {
candle_nn::linear_no_bias(config.hidden_size, kv_dim, vb.pp("v_proj"))?
};
QkvProj::Separate {
q_proj,
k_proj,
v_proj,
}
}
QkvLayout::Fused => {
let total_dim = q_dim + 2 * kv_dim;
let qkv_proj = if config.qkv_bias {
candle_nn::linear(config.hidden_size, total_dim, vb.pp("qkv_proj"))?
} else {
candle_nn::linear_no_bias(config.hidden_size, total_dim, vb.pp("qkv_proj"))?
};
QkvProj::Fused {
qkv_proj,
q_dim,
kv_dim,
}
}
};
let o_proj = if config.o_proj_bias {
candle_nn::linear(q_dim, config.hidden_size, vb.pp("o_proj"))?
} else {
candle_nn::linear_no_bias(q_dim, config.hidden_size, vb.pp("o_proj"))?
};
#[allow(clippy::cast_precision_loss, clippy::as_conversions)]
let scale = config.query_pre_attn_scalar.map_or_else(
|| 1.0 / (config.head_dim as f64).sqrt(),
|scalar| 1.0 / scalar.sqrt(),
);
Ok(Self {
qkv,
o_proj,
num_attention_heads: config.num_attention_heads,
num_kv_heads: config.num_kv_heads,
head_dim: config.head_dim,
scale,
attn_logit_softcapping: config.attn_logit_softcapping,
})
}
pub fn forward(
&self,
x: &Tensor,
mask: &Tensor,
rope: &RopeCache,
layer_idx: usize,
hooks: &HookSpec,
cache: &mut HookCache,
) -> Result<Tensor> {
let (batch, seq_len, _hidden) = x.dims3()?;
let (q, k, v) = self.qkv.forward(x)?;
let mut q = q
.reshape((batch, seq_len, self.num_attention_heads, self.head_dim))?
.transpose(1, 2)?;
let mut k = k
.reshape((batch, seq_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let mut v = v
.reshape((batch, seq_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
if hooks.is_captured(&HookPoint::AttnQ(layer_idx)) {
cache.store(HookPoint::AttnQ(layer_idx), q.clone());
}
for intervention in hooks.interventions_at(&HookPoint::AttnQ(layer_idx)) {
q = crate::hooks::apply_intervention(&q, intervention)?;
}
if hooks.is_captured(&HookPoint::AttnK(layer_idx)) {
cache.store(HookPoint::AttnK(layer_idx), k.clone());
}
for intervention in hooks.interventions_at(&HookPoint::AttnK(layer_idx)) {
k = crate::hooks::apply_intervention(&k, intervention)?;
}
if hooks.is_captured(&HookPoint::AttnV(layer_idx)) {
cache.store(HookPoint::AttnV(layer_idx), v.clone());
}
for intervention in hooks.interventions_at(&HookPoint::AttnV(layer_idx)) {
v = crate::hooks::apply_intervention(&v, intervention)?;
}
let q = rope.apply(&q, 0)?;
let k = rope.apply(&k, 0)?;
let k = repeat_kv(k, self.num_attention_heads, self.num_kv_heads)?;
let v = repeat_kv(v, self.num_attention_heads, self.num_kv_heads)?;
let k_t = k.contiguous()?.transpose(2, 3)?;
let q = q.contiguous()?;
let mut scores = q.matmul(&k_t)?;
scores = (scores * self.scale)?;
if hooks.is_captured(&HookPoint::AttnScores(layer_idx)) {
cache.store(HookPoint::AttnScores(layer_idx), scores.clone());
}
for intervention in hooks.interventions_at(&HookPoint::AttnScores(layer_idx)) {
scores = crate::hooks::apply_intervention(&scores, intervention)?;
}
if let Some(cap) = self.attn_logit_softcapping {
scores = ((scores / cap)?.tanh()? * cap)?;
}
scores = scores.broadcast_add(mask)?;
let original_dtype = scores.dtype();
let scores_f32 = if original_dtype == DType::F32 {
scores
} else {
scores.to_dtype(DType::F32)?
};
let mut pattern = candle_nn::ops::softmax_last_dim(&scores_f32)?;
if original_dtype != DType::F32 {
pattern = pattern.to_dtype(original_dtype)?;
}
if hooks.is_captured(&HookPoint::AttnPattern(layer_idx)) {
cache.store(HookPoint::AttnPattern(layer_idx), pattern.clone());
}
for intervention in hooks.interventions_at(&HookPoint::AttnPattern(layer_idx)) {
pattern = crate::hooks::apply_intervention(&pattern, intervention)?;
}
let v = v.contiguous()?;
let attn_output = pattern.matmul(&v)?;
let attn_output = attn_output.transpose(1, 2)?.contiguous()?.reshape((
batch,
seq_len,
self.num_attention_heads * self.head_dim,
))?;
Ok(self.o_proj.forward(&attn_output)?)
}
}
/// Repeats each key/value head so that `x` ends up with `n_heads` heads.
///
/// `x` is read as a four-dimensional `(batch, kv_heads, seq_len, head_dim)`
/// tensor. When `n_heads == n_kv_heads` it is returned untouched. Otherwise a
/// new axis is inserted after the head axis, expanded to
/// `n_heads / n_kv_heads` entries, and folded back into the head axis.
///
/// # Errors
///
/// Propagates any error from the tensor shape operations.
fn repeat_kv(x: Tensor, n_heads: usize, n_kv_heads: usize) -> Result<Tensor> {
    if n_heads == n_kv_heads {
        Ok(x)
    } else {
        let group_size = n_heads / n_kv_heads;
        let (batch, _kv_heads, seq_len, head_dim) = x.dims4()?;
        let with_group_axis = x.unsqueeze(2)?;
        let tiled = with_group_axis.expand((batch, n_kv_heads, group_size, seq_len, head_dim))?;
        let merged = tiled.reshape((batch, n_heads, seq_len, head_dim))?;
        Ok(merged)
    }
}