use candle_core::{Result, Tensor, DType, Device};
use candle_nn::VarBuilder;
use super::activations::CastedLinear;
use super::positional::{apply_rotary_pos_emb, RotaryEmbedding};
/// Multi-head self-attention layer with a fused query/key/value projection.
///
/// When `num_key_value_heads` is smaller than `num_heads`, the forward pass
/// repeats each key/value head so that groups of query heads share it.
pub struct Attention {
/// Input size of `qkv_proj` and output size of `o_proj`.
hidden_size: usize,
/// Size of each individual attention head.
head_dim: usize,
/// `head_dim * num_heads`: width of the concatenated head outputs fed to `o_proj`.
output_size: usize,
/// Number of query heads.
num_heads: usize,
/// Number of key heads (and, equally, value heads) produced by `qkv_proj`.
num_key_value_heads: usize,
/// When true, an additive mask with `-inf` above the diagonal is applied to
/// the attention scores before the softmax.
causal: bool,
/// Fused projection producing `(num_heads + 2 * num_key_value_heads) * head_dim` features.
qkv_proj: CastedLinear,
/// Output projection from `output_size` features back to `hidden_size`.
o_proj: CastedLinear,
}
impl Attention {
pub fn new(
hidden_size: usize,
head_dim: usize,
num_heads: usize,
num_key_value_heads: usize,
causal: bool,
vb: VarBuilder,
) -> Result<Self> {
let output_size = head_dim * num_heads;
let qkv_size = (num_heads + 2 * num_key_value_heads) * head_dim;
let qkv_proj = CastedLinear::new(
hidden_size,
qkv_size,
false,
vb.pp("qkv_proj"),
)?;
let o_proj = CastedLinear::new(
output_size,
hidden_size,
false,
vb.pp("o_proj"),
)?;
Ok(Self {
hidden_size,
head_dim,
output_size,
num_heads,
num_key_value_heads,
causal,
qkv_proj,
o_proj,
})
}
pub fn forward(
&self,
hidden_states: &Tensor,
cos_sin: Option<(&Tensor, &Tensor)>,
) -> Result<Tensor> {
let (batch_size, seq_len, _) = hidden_states.dims3()?;
let qkv = self.qkv_proj.forward(hidden_states)?;
let qkv = qkv.reshape((
batch_size,
seq_len,
self.num_heads + 2 * self.num_key_value_heads,
self.head_dim,
))?;
let query = qkv.narrow(2, 0, self.num_heads)?; let key = qkv.narrow(2, self.num_heads, self.num_key_value_heads)?;
let value = qkv.narrow(2, self.num_heads + self.num_key_value_heads, self.num_key_value_heads)?;
let (query, key) = if let Some((cos, sin)) = cos_sin {
apply_rotary_pos_emb(&query, &key, cos, sin)?
} else {
(query, key)
};
let query = query.transpose(1, 2)?.contiguous()?;
let key = key.transpose(1, 2)?.contiguous()?;
let value = value.transpose(1, 2)?.contiguous()?;
let (key, value) = if self.num_key_value_heads < self.num_heads {
let repeat_factor = self.num_heads / self.num_key_value_heads;
(
repeat_kv(&key, repeat_factor)?,
repeat_kv(&value, repeat_factor)?,
)
} else {
(key, value)
};
let attn_output = scaled_dot_product_attention(
&query,
&key,
&value,
self.causal,
)?;
let attn_output = attn_output.transpose(1, 2)?;
let attn_output = attn_output.reshape((batch_size, seq_len, self.output_size))?;
self.o_proj.forward(&attn_output)
}
}
/// Repeats every head of `x` `n` times along the head axis.
///
/// `x` must be rank 4, read as `(batch, num_kv_heads, seq_len, head_dim)`; the
/// result has shape `(batch, num_kv_heads * n, seq_len, head_dim)`, with the
/// copies of each source head placed next to each other. For `n == 1` the
/// input is returned as a clone without any shape inspection.
fn repeat_kv(x: &Tensor, n: usize) -> Result<Tensor> {
    match n {
        1 => Ok(x.clone()),
        _ => {
            let (b, kv_heads, len, dim) = x.dims4()?;
            // Insert a length-1 axis after the head axis, broadcast it to `n`,
            // then fold it back into the head axis.
            x.unsqueeze(2)?
                .broadcast_as((b, kv_heads, n, len, dim))?
                .reshape((b, kv_heads * n, len, dim))
        }
    }
}
/// Computes `softmax(Q · Kᵀ / sqrt(head_dim)) · V`.
///
/// `query` must be rank 4; its last two dimensions are taken as
/// `(seq_len, head_dim)`. `key` and `value` must be shaped so that the two
/// matrix multiplications are valid. When `causal` is true, an additive mask
/// with `-inf` strictly above the diagonal is applied to the scores before the
/// softmax, so each position only attends to itself and earlier positions.
///
/// # Errors
///
/// Propagates any error from the underlying tensor operations (rank, shape or
/// dtype mismatches).
fn scaled_dot_product_attention(
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
    causal: bool,
) -> Result<Tensor> {
    let (_batch, _num_heads, seq_len, head_dim) = query.dims4()?;
    let scale = 1.0 / (head_dim as f64).sqrt();
    let scores = query.matmul(&key.transpose(2, 3)?)?;
    let scores = (scores * scale)?;
    let scores = if causal {
        // The mask is created as f32; cast it to the scores' dtype so the
        // addition also works for f16/bf16 activations (binary ops require
        // matching dtypes).
        let mask = create_causal_mask(seq_len, scores.device())?.to_dtype(scores.dtype())?;
        scores.broadcast_add(&mask)?
    } else {
        scores
    };
    let attn_weights = candle_nn::ops::softmax_last_dim(&scores)?;
    attn_weights.matmul(value)
}
/// Builds a `(seq_len, seq_len)` f32 additive causal mask on `device`.
///
/// Entry `(row, col)` is `-inf` when `col > row` and `0.0` otherwise, so
/// adding it to attention scores removes attention to later positions.
fn create_causal_mask(seq_len: usize, device: &Device) -> Result<Tensor> {
    let mask_data: Vec<f32> = (0..seq_len)
        .flat_map(|row| {
            (0..seq_len).map(move |col| if col > row { f32::NEG_INFINITY } else { 0.0 })
        })
        .collect();
    Tensor::from_vec(mask_data, (seq_len, seq_len), device)
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_nn::VarMap;

    // Shared fixture dimensions used by the shape tests.
    const BATCH: usize = 2;
    const SEQ_LEN: usize = 16;
    const HIDDEN: usize = 256;
    const HEAD_DIM: usize = 32;
    const NUM_HEADS: usize = 8;

    /// Creates a non-causal attention layer backed by a fresh `VarMap`.
    fn fresh_attention(num_kv_heads: usize, device: &Device) -> Result<Attention> {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
        Attention::new(HIDDEN, HEAD_DIM, NUM_HEADS, num_kv_heads, false, vb)
    }

    /// Random `(BATCH, SEQ_LEN, HIDDEN)` input.
    fn random_input(device: &Device) -> Result<Tensor> {
        Tensor::randn(0f32, 1.0, (BATCH, SEQ_LEN, HIDDEN), device)
    }

    /// True when `v` is exactly negative infinity.
    fn is_neg_inf(v: f32) -> bool {
        v.is_infinite() && v.is_sign_negative()
    }

    #[test]
    fn test_attention_shape() -> Result<()> {
        let device = Device::Cpu;
        let attn = fresh_attention(NUM_HEADS, &device)?;
        let out = attn.forward(&random_input(&device)?, None)?;
        assert_eq!(out.dims(), &[BATCH, SEQ_LEN, HIDDEN]);
        Ok(())
    }

    #[test]
    fn test_attention_with_rope() -> Result<()> {
        let device = Device::Cpu;
        let attn = fresh_attention(NUM_HEADS, &device)?;
        let x = random_input(&device)?;
        let rope = RotaryEmbedding::new(HEAD_DIM, 512, 10000.0, &device)?;
        let (cos, sin) = rope.forward_with_len(SEQ_LEN)?;
        let out = attn.forward(&x, Some((&cos, &sin)))?;
        assert_eq!(out.dims(), &[BATCH, SEQ_LEN, HIDDEN]);
        Ok(())
    }

    #[test]
    fn test_grouped_query_attention() -> Result<()> {
        let device = Device::Cpu;
        // Two KV heads shared by eight query heads.
        let attn = fresh_attention(2, &device)?;
        let out = attn.forward(&random_input(&device)?, None)?;
        assert_eq!(out.dims(), &[BATCH, SEQ_LEN, HIDDEN]);
        Ok(())
    }

    #[test]
    fn test_causal_mask() -> Result<()> {
        let device = Device::Cpu;
        let mask = create_causal_mask(4, &device)?;
        assert_eq!(mask.dims(), &[4, 4]);
        let flat = mask.flatten_all()?.to_vec1::<f32>()?;
        // Row 0: the diagonal is open, the next column is masked.
        assert_eq!(flat[0], 0.0);
        assert!(is_neg_inf(flat[1]));
        // Row 1: columns 0 and 1 are open, column 2 is masked.
        assert_eq!(flat[4], 0.0);
        assert_eq!(flat[5], 0.0);
        assert!(is_neg_inf(flat[6]));
        Ok(())
    }

    #[test]
    fn test_repeat_kv() -> Result<()> {
        let device = Device::Cpu;
        let x = Tensor::randn(0f32, 1.0, (2, 2, 16, 32), &device)?;
        let repeated = repeat_kv(&x, 4)?;
        assert_eq!(repeated.dims(), &[2, 8, 16, 32]);
        Ok(())
    }
}