use burn::prelude::*;
use burn::nn::{Linear, LinearConfig};
use burn::tensor::activation::softmax;
/// Multi-head self-attention layer.
///
/// A single fused linear layer produces queries, keys and values. Attention is
/// computed per head, and the concatenated head outputs are projected by a
/// second linear layer.
#[derive(Module, Debug)]
pub struct Attention<B: Backend> {
/// Fused projection whose output holds queries, keys and values concatenated
/// along the last axis, each `n_heads * dim_head` wide.
pub to_qkv: Linear<B>,
/// Output projection applied to the merged heads (`n_heads * dim_head` inputs).
pub to_out: Linear<B>,
/// Number of attention heads.
pub n_heads: usize,
/// Width of each individual head.
pub dim_head: usize,
}
impl<B: Backend> Attention<B> {
    /// Builds a multi-head self-attention layer.
    ///
    /// * `input_dim` – size of the last axis of the tensor given to [`Attention::forward`].
    /// * `output_dim` – size of the last axis of the tensor it returns.
    /// * `heads` – number of attention heads.
    /// * `dim_head` – width of each head; the fused projection is `heads * dim_head` wide
    ///   per query/key/value.
    /// * `qkv_bias` – whether the fused query/key/value projection carries a bias.
    /// * `device` – device on which the parameters are created.
    pub fn new(
        input_dim: usize,
        output_dim: usize,
        heads: usize,
        dim_head: usize,
        qkv_bias: bool,
        device: &B::Device,
    ) -> Self {
        let proj_width = heads * dim_head;

        // Created before `to_out` so parameters are initialised in the same order as before.
        let to_qkv = LinearConfig::new(input_dim, 3 * proj_width)
            .with_bias(qkv_bias)
            .init(device);
        let to_out = LinearConfig::new(proj_width, output_dim)
            .with_bias(true)
            .init(device);

        Self {
            to_qkv,
            to_out,
            n_heads: heads,
            dim_head,
        }
    }

    /// Applies scaled dot-product self-attention across the second axis of `x`.
    ///
    /// `x` is destructured as `[batch, seq_len, _]`. The returned tensor is
    /// `[batch, seq_len, output_dim]`, where `output_dim` is the width of `to_out`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let [batch, seq_len, _] = x.dims();
        let heads = self.n_heads;
        let head_dim = self.dim_head;
        let width = heads * head_dim;

        // [batch, seq_len, width] -> [batch, heads, seq_len, head_dim]
        let to_heads =
            |t: Tensor<B, 3>| t.reshape([batch, seq_len, heads, head_dim]).swap_dims(1, 2);

        // Fused output is laid out as [query | key | value] along the last axis.
        let fused = self.to_qkv.forward(x);
        let query = to_heads(fused.clone().narrow(2, 0, width));
        let key = to_heads(fused.clone().narrow(2, width, width));
        let value = to_heads(fused.narrow(2, 2 * width, width));

        let scale = (head_dim as f64).powf(-0.5) as f32;
        let logits = query.matmul(key.transpose()).mul_scalar(scale);
        // Normalise over the key axis (last axis of the rank-4 score tensor).
        let weights = softmax(logits, 3);

        // Back to [batch, seq_len, width] by undoing the head split.
        let merged = weights
            .matmul(value)
            .swap_dims(1, 2)
            .reshape([batch, seq_len, width]);
        self.to_out.forward(merged)
    }
}