use burn::prelude::*;
use burn::nn::{Linear, LinearConfig};
use burn::tensor::activation::softmax;

/// Multi-head self-attention with a single fused QKV projection and optional
/// rotary position embeddings (RoPE) applied to the query/key heads.
#[derive(Module, Debug)]
pub struct Attention<B: Backend> {
    /// Projects the input to concatenated Q, K, V (`dim -> 3 * dim`).
    pub to_qkv: Linear<B>,
    /// Output projection back to the model dimension (`dim -> dim`).
    pub to_out: Linear<B>,
    /// Number of attention heads.
    pub heads: usize,
    /// Per-head dimension (`dim / heads`).
    pub dim_head: usize,
    /// Softmax scale, `1 / sqrt(dim_head)`.
    pub scale: f32,
}

impl<B: Backend> Attention<B> {
    pub fn new(dim: usize, heads: usize, device: &B::Device) -> Self {
        // Guard against silent truncation: a non-divisible `dim` would make
        // `heads * dim_head != dim` and fail later with a shape error.
        debug_assert_eq!(
            dim % heads,
            0,
            "model dimension must be divisible by the number of heads"
        );
        let dim_head = dim / heads;
        Self {
            to_qkv: LinearConfig::new(dim, 3 * dim).with_bias(false).init(device),
            to_out: LinearConfig::new(dim, dim).with_bias(false).init(device),
            heads,
            dim_head,
            scale: (dim_head as f32).powf(-0.5),
        }
    }
}
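
// Usage sketch (hedged: `MyBackend` stands in for whichever Burn backend the
// crate is compiled with, and the optional cos/sin tables are assumed to be
// precomputed elsewhere with shape [seq_len, dim_head / 2]):
//
//     let attn = Attention::<MyBackend>::new(512, 8, &device);
//     // x: [batch, seq_len, 512]
//     let y = attn.forward(x, Some(&cos), Some(&sin)); // [batch, seq_len, 512]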

#[cfg(not(feature = "wgpu-kernels-metal"))]
impl<B: Backend> Attention<B> {
    pub fn forward(
        &self,
        x: Tensor<B, 3>,
        rotary_cos: Option<&Tensor<B, 2>>,
        rotary_sin: Option<&Tensor<B, 2>>,
    ) -> Tensor<B, 3> {
        attention_forward(self, x, rotary_cos, rotary_sin)
    }
}

// Fused path: identical signature, but the backend must also provide the
// custom fused kernels via `FusedOps`.
#[cfg(feature = "wgpu-kernels-metal")]
impl<B: Backend + crate::model_burn::FusedOps> Attention<B> {
    pub fn forward(
        &self,
        x: Tensor<B, 3>,
        rotary_cos: Option<&Tensor<B, 2>>,
        rotary_sin: Option<&Tensor<B, 2>>,
    ) -> Tensor<B, 3> {
        attention_forward(self, x, rotary_cos, rotary_sin)
    }
}
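
// For orientation: the fused path assumes `crate::model_burn::FusedOps`
// exposes a RoPE rotation kernel. A rough sketch of the expected shape,
// inferred from the `B::rope_rotate` call below (hypothetical reconstruction;
// the authoritative definition lives in `crate::model_burn`):
//
//     pub trait FusedOps: Backend {
//         fn rope_rotate(
//             x: Tensor<Self, 4>,
//             cos: Tensor<Self, 2>,
//             sin: Tensor<Self, 2>,
//         ) -> Tensor<Self, 4>;
//     }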

// Free-function shims so both cfg variants share one call site in `forward`;
// each simply dispatches to the matching body below.
#[cfg(not(feature = "wgpu-kernels-metal"))]
fn attention_forward<B: Backend>(
    this: &Attention<B>,
    x: Tensor<B, 3>,
    rotary_cos: Option<&Tensor<B, 2>>,
    rotary_sin: Option<&Tensor<B, 2>>,
) -> Tensor<B, 3> {
    _attention_forward_body(this, x, rotary_cos, rotary_sin)
}

#[cfg(feature = "wgpu-kernels-metal")]
fn attention_forward<B: Backend + crate::model_burn::FusedOps>(
    this: &Attention<B>,
    x: Tensor<B, 3>,
    rotary_cos: Option<&Tensor<B, 2>>,
    rotary_sin: Option<&Tensor<B, 2>>,
) -> Tensor<B, 3> {
    _attention_forward_body(this, x, rotary_cos, rotary_sin)
}

#[cfg(not(feature = "wgpu-kernels-metal"))]
fn _attention_forward_body<B: Backend>(
    this: &Attention<B>,
    x: Tensor<B, 3>,
    rotary_cos: Option<&Tensor<B, 2>>,
    rotary_sin: Option<&Tensor<B, 2>>,
) -> Tensor<B, 3> {
    let [b, n, _d] = x.dims();
    let (h, dh) = (this.heads, this.dim_head);
    let dim = h * dh;
    // Single projection producing Q, K, V concatenated along the feature axis.
    let qkv = this.to_qkv.forward(x);
    // Split into Q/K/V, then reshape to [batch, heads, seq, dim_head].
    let q = qkv
        .clone()
        .narrow(2, 0, dim)
        .reshape([b, n, h, dh])
        .swap_dims(1, 2);
    let k = qkv
        .clone()
        .narrow(2, dim, dim)
        .reshape([b, n, h, dh])
        .swap_dims(1, 2);
    let v = qkv
        .narrow(2, 2 * dim, dim)
        .reshape([b, n, h, dh])
        .swap_dims(1, 2);
    // Apply rotary embeddings to Q and K when precomputed tables are given.
    let (q, k) = match (rotary_cos, rotary_sin) {
        (Some(cos), Some(sin)) => (
            apply_rotary_precomputed(q, cos, sin),
            apply_rotary_precomputed(k, cos, sin),
        ),
        _ => (q, k),
    };
    // Scaled dot-product attention: softmax(Q K^T / sqrt(dim_head)) V.
    let q = q.mul_scalar(this.scale);
    let attn = softmax(q.matmul(k.swap_dims(2, 3)), 3);
    let out = attn.matmul(v);
    // Merge heads back to [batch, seq, dim] and project out.
    let out = out.swap_dims(1, 2).flatten(2, 3);
    this.to_out.forward(out)
}

#[cfg(feature = "wgpu-kernels-metal")]
fn _attention_forward_body<B: Backend + crate::model_burn::FusedOps>(
    this: &Attention<B>,
    x: Tensor<B, 3>,
    rotary_cos: Option<&Tensor<B, 2>>,
    rotary_sin: Option<&Tensor<B, 2>>,
) -> Tensor<B, 3> {
    let [b, n, _d] = x.dims();
    let (h, dh) = (this.heads, this.dim_head);
    let dim = h * dh;
    let qkv = this.to_qkv.forward(x);
    let q = qkv
        .clone()
        .narrow(2, 0, dim)
        .reshape([b, n, h, dh])
        .swap_dims(1, 2);
    let k = qkv
        .clone()
        .narrow(2, dim, dim)
        .reshape([b, n, h, dh])
        .swap_dims(1, 2);
    let v = qkv
        .narrow(2, 2 * dim, dim)
        .reshape([b, n, h, dh])
        .swap_dims(1, 2);
    // Identical to the default body except that the rotation is delegated to
    // the backend's fused kernel via `rotate_qk`.
    let (q, k) = match (rotary_cos, rotary_sin) {
        (Some(cos), Some(sin)) => (
            rotate_qk::<B>(q, cos, sin),
            rotate_qk::<B>(k, cos, sin),
        ),
        _ => (q, k),
    };
    let q = q.mul_scalar(this.scale);
    let attn = softmax(q.matmul(k.swap_dims(2, 3)), 3);
    let out = attn.matmul(v);
    let out = out.swap_dims(1, 2).flatten(2, 3);
    this.to_out.forward(out)
}

// Rotate-half RoPE: the first `rot_dim = 2 * half` features of each head are
// split into two halves (x1, x2) and rotated pairwise by the precomputed
// angles; any remaining features pass through unchanged (partial rotary).
#[cfg(not(feature = "wgpu-kernels-metal"))]
fn apply_rotary_precomputed<B: Backend>(
    x: Tensor<B, 4>,
    cos: &Tensor<B, 2>,
    sin: &Tensor<B, 2>,
) -> Tensor<B, 4> {
    let [b, h, n, d] = x.dims();
    let half = cos.dims()[1];
    let rot_dim = half * 2;
    let x_rot = x.clone().slice([0..b, 0..h, 0..n, 0..rot_dim]);
    let x1 = x_rot.clone().slice([0..b, 0..h, 0..n, 0..half]);
    let x2 = x_rot.slice([0..b, 0..h, 0..n, half..rot_dim]);
    // Broadcast the [seq, half] tables to [1, 1, seq, half].
    let c = cos.clone().unsqueeze_dim::<3>(0).unsqueeze_dim::<4>(0);
    let s = sin.clone().unsqueeze_dim::<3>(0).unsqueeze_dim::<4>(0);
    // Standard 2-D rotation applied to each (x1, x2) pair.
    let r1 = x1.clone() * c.clone() - x2.clone() * s.clone();
    let r2 = x2 * c + x1 * s;
    let rotated = Tensor::cat(vec![r1, r2], 3);
    if rot_dim < d {
        // Re-attach the unrotated tail of the head dimension.
        let x_pass = x.slice([0..b, 0..h, 0..n, rot_dim..d]);
        Tensor::cat(vec![rotated, x_pass], 3)
    } else {
        rotated
    }
}
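
// A minimal sketch of how matching cos/sin tables could be precomputed
// (hypothetical helper, not part of the original module; assumes the common
// RoPE frequency schedule theta_i = 10000^(-2i / rot_dim)):
#[allow(dead_code)]
fn build_rotary_tables<B: Backend>(
    seq_len: usize,
    half: usize,
    device: &B::Device,
) -> (Tensor<B, 2>, Tensor<B, 2>) {
    let rot_dim = (half * 2) as f32;
    let mut cos = Vec::with_capacity(seq_len * half);
    let mut sin = Vec::with_capacity(seq_len * half);
    for pos in 0..seq_len {
        for i in 0..half {
            // angle(pos, i) = pos * 10000^(-2i / rot_dim)
            let freq = 10_000f32.powf(-2.0 * i as f32 / rot_dim);
            let angle = pos as f32 * freq;
            cos.push(angle.cos());
            sin.push(angle.sin());
        }
    }
    let shape = [seq_len, half];
    (
        Tensor::from_data(burn::tensor::TensorData::new(cos, shape), device),
        Tensor::from_data(burn::tensor::TensorData::new(sin, shape), device),
    )
}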

// Fused RoPE: delegate the whole rotation to the backend kernel.
#[cfg(feature = "wgpu-kernels-metal")]
fn rotate_qk<B: Backend + crate::model_burn::FusedOps>(
    x: Tensor<B, 4>,
    cos: &Tensor<B, 2>,
    sin: &Tensor<B, 2>,
) -> Tensor<B, 4> {
    B::rope_rotate(x, cos.clone(), sin.clone())
}
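
// A minimal shape test for the default path (a sketch: assumes the `ndarray`
// feature of `burn` is enabled for tests; any CPU backend would do).
#[cfg(all(test, not(feature = "wgpu-kernels-metal")))]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn forward_preserves_shape() {
        let device = Default::default();
        let attn = Attention::<NdArray>::new(64, 4, &device);
        let x = Tensor::<NdArray, 3>::random(
            [2, 16, 64],
            burn::tensor::Distribution::Default,
            &device,
        );
        // Without rotary tables the layer is plain multi-head attention;
        // the output shape must match the input shape.
        let out = attn.forward(x, None, None);
        assert_eq!(out.dims(), [2, 16, 64]);
    }
}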