use super::attention_helpers::{
divide_with_eps, elu_feature_map, matmul_with_broadcast, repeat_kv_heads, sum_last_dim,
};
use super::positional_encoding::{
matmul_batched, reshape_for_attention, reshape_from_attention, transpose_last_two,
};
#[allow(clippy::wildcard_imports)]
use super::*;
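/// Slices the first `seq_len` rows out of a precomputed positional-encoding
/// table `pe`, returning a `[seq_len, d_model]` tensor.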
pub(super) fn slice_pe(pe: &Tensor, seq_len: usize, d_model: usize) -> Tensor {
let data: Vec<f32> = pe.data()[..seq_len * d_model].to_vec();
Tensor::new(&data, &[seq_len, d_model])
}
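/// Adds the positional encoding `pe` (`[seq_len, d_model]`) to every batch
/// element of `x` (`[batch, seq_len, d_model]`), broadcasting over the batch
/// dimension.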
pub(super) fn add_positional_encoding(x: &Tensor, pe: &Tensor) -> Tensor {
let batch_size = x.shape()[0];
let seq_len = x.shape()[1];
let d_model = x.shape()[2];
let mut output = vec![0.0; x.data().len()];
for b in 0..batch_size {
for s in 0..seq_len {
for d in 0..d_model {
let x_idx = b * seq_len * d_model + s * d_model + d;
let pe_idx = s * d_model + d;
output[x_idx] = x.data()[x_idx] + pe.data()[pe_idx];
}
}
}
Tensor::new(&output, x.shape())
}
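/// Builds an additive causal mask of shape `[size, size]`: entries above the
/// diagonal (`j > i`) are `-inf`, everything else is `0.0`, so position `i`
/// can only attend to positions `j <= i`.
///
/// For `size = 3` the mask is:
///
/// ```text
///   0  -inf -inf
///   0    0  -inf
///   0    0    0
/// ```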
#[must_use]
pub fn generate_causal_mask(size: usize) -> Tensor {
let mut data = vec![0.0; size * size];
for i in 0..size {
for j in 0..size {
if j > i {
data[i * size + j] = f32::NEG_INFINITY;
}
}
}
Tensor::new(&data, &[size, size])
}
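/// Multi-head attention using the linear-attention kernel trick: queries and
/// keys are passed through the feature map `phi(x) = elu(x) + 1`, so the
/// output can be computed as `phi(Q) (phi(K)^T V)` and normalized by
/// `phi(Q) * sum_j phi(K_j)`, avoiding the explicit `[seq_len, seq_len]`
/// attention matrix.
///
/// A minimal usage sketch (assuming inputs shaped `[batch, seq_len, embed_dim]`):
///
/// ```rust,ignore
/// let attn = LinearAttention::new(64, 8);
/// let x = Tensor::new(&vec![0.0; 2 * 10 * 64], &[2, 10, 64]);
/// let y = attn.forward_linear(&x, &x, &x); // [2, 10, 64]
/// ```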
pub struct LinearAttention {
embed_dim: usize,
num_heads: usize,
head_dim: usize,
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
out_proj: Linear,
eps: f32,
training: bool,
}
impl LinearAttention {
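/// Creates a new `LinearAttention` layer with freshly initialized projections.
///
/// # Panics
///
/// Panics if `embed_dim` is not divisible by `num_heads`.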
#[must_use]
pub fn new(embed_dim: usize, num_heads: usize) -> Self {
assert!(
embed_dim.is_multiple_of(num_heads),
"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
);
let head_dim = embed_dim / num_heads;
Self {
embed_dim,
num_heads,
head_dim,
q_proj: Linear::new(embed_dim, embed_dim),
k_proj: Linear::new(embed_dim, embed_dim),
v_proj: Linear::new(embed_dim, embed_dim),
out_proj: Linear::new(embed_dim, embed_dim),
eps: 1e-6,
training: true,
}
}
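/// Computes linear attention over `query` (`[batch, tgt_len, embed_dim]`),
/// `key`, and `value` (`[batch, src_len, embed_dim]`), returning a
/// `[batch, tgt_len, embed_dim]` tensor. Note that this linear-attention path
/// does not take an attention mask.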
#[must_use]
pub fn forward_linear(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Tensor {
let batch_size = query.shape()[0];
let tgt_len = query.shape()[1];
let src_len = key.shape()[1];
let q = self.q_proj.forward(query);
let k = self.k_proj.forward(key);
let v = self.v_proj.forward(value);
let q = reshape_for_attention(&q, batch_size, tgt_len, self.num_heads, self.head_dim);
let k = reshape_for_attention(&k, batch_size, src_len, self.num_heads, self.head_dim);
let v = reshape_for_attention(&v, batch_size, src_len, self.num_heads, self.head_dim);
let q_prime = elu_feature_map(&q);
let k_prime = elu_feature_map(&k);
// Aggregate keys and values once per head: phi(K)^T V.
let k_prime_t = transpose_last_two(&k_prime);
let kv = matmul_batched(&k_prime_t, &v);
// Unnormalized output: phi(Q) (phi(K)^T V).
let output = matmul_batched(&q_prime, &kv);
// Normalizer for each query position: phi(Q) · sum_j phi(K_j); eps guards against division by zero.
let k_sum = sum_last_dim(&k_prime);
let normalizer = matmul_with_broadcast(&q_prime, &k_sum);
let output = divide_with_eps(&output, &normalizer, self.eps);
let output = reshape_from_attention(&output, batch_size, tgt_len, self.embed_dim);
self.out_proj.forward(&output)
}
#[must_use]
pub fn embed_dim(&self) -> usize {
self.embed_dim
}
#[must_use]
pub fn num_heads(&self) -> usize {
self.num_heads
}
}
impl Module for LinearAttention {
fn forward(&self, input: &Tensor) -> Tensor {
self.forward_linear(input, input, input)
}
fn parameters(&self) -> Vec<&Tensor> {
let mut params = self.q_proj.parameters();
params.extend(self.k_proj.parameters());
params.extend(self.v_proj.parameters());
params.extend(self.out_proj.parameters());
params
}
fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
let mut params = self.q_proj.parameters_mut();
params.extend(self.k_proj.parameters_mut());
params.extend(self.v_proj.parameters_mut());
params.extend(self.out_proj.parameters_mut());
params
}
fn train(&mut self) {
self.training = true;
}
fn eval(&mut self) {
self.training = false;
}
fn training(&self) -> bool {
self.training
}
}
impl std::fmt::Debug for LinearAttention {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LinearAttention")
.field("embed_dim", &self.embed_dim)
.field("num_heads", &self.num_heads)
.field("head_dim", &self.head_dim)
.finish_non_exhaustive()
}
}
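/// Grouped-query attention (GQA): `num_heads` query heads share a smaller set
/// of `num_kv_heads` key/value heads. The K and V projections produce only
/// `num_kv_heads * head_dim` features, and each KV head is repeated across its
/// group of query heads before scaled dot-product attention, shrinking the K/V
/// projections (and any KV cache) while keeping the full number of query heads.
///
/// A minimal usage sketch (assuming inputs shaped `[batch, seq_len, embed_dim]`):
///
/// ```rust,ignore
/// let gqa = GroupedQueryAttention::new(64, 8, 2).with_dropout(0.1);
/// let x = Tensor::new(&vec![0.0; 2 * 10 * 64], &[2, 10, 64]);
/// let mask = generate_causal_mask(10);
/// let (output, attn_weights) = gqa.forward_self(&x, Some(&mask));
/// ```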
pub struct GroupedQueryAttention {
pub(crate) embed_dim: usize,
pub(crate) num_heads: usize,
pub(crate) num_kv_heads: usize,
pub(crate) head_dim: usize,
pub(crate) kv_head_dim: usize,
pub(crate) dropout_p: f32,
pub(crate) q_proj: Linear,
pub(crate) k_proj: Linear,
pub(crate) v_proj: Linear,
pub(crate) out_proj: Linear,
pub(crate) training: bool,
}
impl GroupedQueryAttention {
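/// Creates a new `GroupedQueryAttention` layer. With `num_kv_heads == num_heads`
/// this reduces to standard multi-head attention; with `num_kv_heads == 1` it is
/// multi-query attention.
///
/// # Panics
///
/// Panics if `embed_dim` is not divisible by `num_heads`, or if `num_heads` is
/// not divisible by `num_kv_heads`.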
#[must_use]
pub fn new(embed_dim: usize, num_heads: usize, num_kv_heads: usize) -> Self {
assert!(
embed_dim.is_multiple_of(num_heads),
"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
);
assert!(
num_heads.is_multiple_of(num_kv_heads),
"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
);
let head_dim = embed_dim / num_heads;
let kv_dim = num_kv_heads * head_dim;
Self {
embed_dim,
num_heads,
num_kv_heads,
head_dim,
kv_head_dim: head_dim,
dropout_p: 0.0,
q_proj: Linear::new(embed_dim, embed_dim),
k_proj: Linear::new(embed_dim, kv_dim),
v_proj: Linear::new(embed_dim, kv_dim),
out_proj: Linear::new(embed_dim, embed_dim),
training: true,
}
}
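/// Like [`new`](Self::new), but builds the projections with
/// `Linear::placeholder` and skips the divisibility checks, e.g. so the
/// weights can be filled in afterwards from a checkpoint.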
#[must_use]
pub fn placeholder(embed_dim: usize, num_heads: usize, num_kv_heads: usize) -> Self {
let head_dim = embed_dim / num_heads;
let kv_dim = num_kv_heads * head_dim;
Self {
embed_dim,
num_heads,
num_kv_heads,
head_dim,
kv_head_dim: head_dim,
dropout_p: 0.0,
q_proj: Linear::placeholder(embed_dim, embed_dim),
k_proj: Linear::placeholder(embed_dim, kv_dim),
v_proj: Linear::placeholder(embed_dim, kv_dim),
out_proj: Linear::placeholder(embed_dim, embed_dim),
training: true,
}
}
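/// Sets the dropout probability passed to scaled dot-product attention
/// (builder style); dropout is only applied while in training mode.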
#[must_use]
pub fn with_dropout(mut self, dropout_p: f32) -> Self {
self.dropout_p = dropout_p;
self
}
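/// Runs grouped-query attention over separate `query` (`[batch, tgt_len, embed_dim]`),
/// `key`, and `value` (`[batch, src_len, embed_dim]`) tensors, with an optional
/// additive attention mask. Returns the projected output
/// (`[batch, tgt_len, embed_dim]`) together with the attention weights.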
#[must_use]
pub fn forward_qkv(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
attn_mask: Option<&Tensor>,
) -> (Tensor, Tensor) {
let batch_size = query.shape()[0];
let tgt_len = query.shape()[1];
let src_len = key.shape()[1];
let q = self.q_proj.forward(query);
let k = self.k_proj.forward(key);
let v = self.v_proj.forward(value);
let q = reshape_for_attention(&q, batch_size, tgt_len, self.num_heads, self.head_dim);
let k = reshape_for_attention(&k, batch_size, src_len, self.num_kv_heads, self.kv_head_dim);
let v = reshape_for_attention(&v, batch_size, src_len, self.num_kv_heads, self.kv_head_dim);
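// Each key/value head serves `num_heads / num_kv_heads` query heads, so tile
// K and V along the head axis before scaled dot-product attention.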
let groups = self.num_heads / self.num_kv_heads;
let k = repeat_kv_heads(&k, groups);
let v = repeat_kv_heads(&v, groups);
let (attn_output, attn_weights) =
scaled_dot_product_attention(&q, &k, &v, attn_mask, self.dropout_p, self.training);
let attn_output = reshape_from_attention(&attn_output, batch_size, tgt_len, self.embed_dim);
let output = self.out_proj.forward(&attn_output);
(output, attn_weights)
}
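/// Self-attention convenience wrapper: uses `x` as query, key, and value.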
#[must_use]
pub fn forward_self(&self, x: &Tensor, attn_mask: Option<&Tensor>) -> (Tensor, Tensor) {
self.forward_qkv(x, x, x, attn_mask)
}
#[must_use]
pub fn embed_dim(&self) -> usize {
self.embed_dim
}
#[must_use]
pub fn num_heads(&self) -> usize {
self.num_heads
}
#[must_use]
pub fn num_kv_heads(&self) -> usize {
self.num_kv_heads
}
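/// Mutable access to the query projection (the other `*_proj_mut` accessors
/// are analogous), e.g. for loading weights into a placeholder layer.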
pub fn q_proj_mut(&mut self) -> &mut Linear {
&mut self.q_proj
}
pub fn k_proj_mut(&mut self) -> &mut Linear {
&mut self.k_proj
}
pub fn v_proj_mut(&mut self) -> &mut Linear {
&mut self.v_proj
}
pub fn out_proj_mut(&mut self) -> &mut Linear {
&mut self.out_proj
}
}
impl Module for GroupedQueryAttention {
fn forward(&self, input: &Tensor) -> Tensor {
let (output, _) = self.forward_self(input, None);
output
}
fn parameters(&self) -> Vec<&Tensor> {
let mut params = self.q_proj.parameters();
params.extend(self.k_proj.parameters());
params.extend(self.v_proj.parameters());
params.extend(self.out_proj.parameters());
params
}
fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
let mut params = self.q_proj.parameters_mut();
params.extend(self.k_proj.parameters_mut());
params.extend(self.v_proj.parameters_mut());
params.extend(self.out_proj.parameters_mut());
params
}
fn train(&mut self) {
self.training = true;
}
fn eval(&mut self) {
self.training = false;
}
fn training(&self) -> bool {
self.training
}
}