use tch::{Tensor, nn};
use crate::common::dropout::Dropout;
use crate::gpt2::gpt2::Gpt2Config;
use tch::kind::Kind::Float;
use tch::nn::{Init, Module};
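
/// 1D convolution layer as used by GPT-2: functionally a linear projection,
/// but with the weight stored as `[input_dim, output_dim]` (the transpose of
/// `nn::Linear`), matching the layout of the original OpenAI checkpoints.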
#[derive(Debug)]
pub struct GPTConv1D {
weight: Tensor,
bias: Tensor,
}
impl GPTConv1D {
pub fn new(p: &nn::Path, nf: i64, nx: i64) -> GPTConv1D {
let weight = p.var("weight", &[nx, nf], Init::Randn { mean: 0., stdev: 0.02 });
let bias = p.var("bias", &[nf], Init::Const(0.));
GPTConv1D { weight, bias }
}
}
impl Module for GPTConv1D {
fn forward(&self, xs: &Tensor) -> Tensor {
xs.matmul(&self.weight) + &self.bias
}
}
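
/// Multi-head causal self-attention layer for GPT-2. Holds the fused
/// query/key/value projection (`c_attn`), the output projection (`c_proj`),
/// the triangular causal mask buffer (`bias`) and the attention / residual
/// dropout layers.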
pub struct Attention {
bias: Tensor,
c_attn: GPTConv1D,
c_proj: GPTConv1D,
attn_dropout: Dropout,
resid_dropout: Dropout,
output_attentions: bool,
n_state: i64,
dim_per_head: i64,
n_head: i64,
scale: bool,
}
impl Attention {
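/// Builds a new `Attention` layer from the model configuration. When `scale`
/// is true, attention scores are divided by the square root of the per-head
/// dimension.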
pub fn new(p: &nn::Path, config: &Gpt2Config, scale: bool) -> Attention {
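// Lower-triangular causal mask of shape (1, 1, n_ctx, n_ctx): position i may
// only attend to positions <= i.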
let bias = Tensor::ones(&[config.n_ctx, config.n_ctx], (Float, p.device()))
.tril(0)
.view((1, 1, config.n_ctx, config.n_ctx));
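// Fused projection producing query, key and value in one pass (output width 3 * n_embd).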
let c_attn = GPTConv1D::new(&(p / "c_attn"), config.n_embd * 3, config.n_embd);
let c_proj = GPTConv1D::new(&(p / "c_proj"), config.n_embd, config.n_embd);
let attn_pdrop = config.attn_pdrop.unwrap_or(0.1);
let resid_pdrop = config.resid_pdrop.unwrap_or(0.1);
let output_attentions = config.output_attentions.unwrap_or(false);
let attn_dropout = Dropout::new(attn_pdrop);
let resid_dropout = Dropout::new(resid_pdrop);
assert_eq!(config.n_embd % config.n_head, 0, "n_embd must be divisible by the number of attention heads");
let dim_per_head = config.n_embd / config.n_head;
Attention {
bias,
c_attn,
c_proj,
attn_dropout,
resid_dropout,
output_attentions,
n_state: config.n_embd,
dim_per_head,
n_head: config.n_head,
scale,
}
}
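/// Reshapes `(batch, seq, n_embd)` activations into per-head tensors. Keys are
/// returned as `(batch, head, head_dim, seq)` so that `q.matmul(k)` yields the
/// attention scores directly; queries and values use `(batch, head, seq, head_dim)`.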
fn split_heads(&self, x: &Tensor, k: bool) -> Tensor {
let x = x.view((x.size()[0], -1, self.n_head, self.dim_per_head));
if k {
x.permute(&[0, 2, 3, 1])
} else {
x.permute(&[0, 2, 1, 3])
}
}
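/// Inverse of `split_heads`: merges the head dimension back into a
/// `(batch, seq, n_embd)` tensor.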
fn flatten(&self, x: Tensor) -> Tensor {
x.transpose(1, 2).contiguous().view((x.size()[0], -1, self.n_head * self.dim_per_head))
}
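/// Masked, (optionally) scaled dot-product attention. Returns the context
/// tensor and, when `output_attentions` is enabled, the attention weights.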
fn attention(&self, q: &Tensor, k: &Tensor, v: &Tensor, attention_mask: &Option<Tensor>, train: bool)
-> (Tensor, Option<Tensor>) {
let mut w = q.matmul(k);
if self.scale { w = w / (*v.size().last().unwrap() as f64).sqrt(); }
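// Apply the causal mask: the relevant window of the triangular buffer keeps
// allowed positions unchanged and pushes masked positions to a large negative
// value (-1e4) before the softmax.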
let (nd, ns) = (w.size()[2], w.size()[3]);
let b = self.bias.narrow(2, ns - nd, nd).narrow(3, 0, ns);
let mut w: Tensor = w * &b + 1e4 * (&b - 1);
if let Some(mask) = attention_mask { w = w + mask; }
w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train);
let output = w.matmul(v);
if self.output_attentions {
(output, Some(w))
} else {
(output, None)
}
}
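/// Forward pass through the attention layer. `layer_past` optionally holds the
/// key/value cache from previous decoding steps; the returned `present` tensor
/// stacks the (transposed) key and the value for reuse at the next step.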
pub fn forward_t(&self, x: &Tensor, layer_past: &Option<Tensor>, attention_mask: &Option<Tensor>, train: bool)
-> (Tensor, Tensor, Option<Tensor>) {
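// Project the input to query/key/value in a single matmul, then split the
// result along the last dimension into three n_state-wide chunks.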
let x = x.apply(&self.c_attn).split(self.n_state, 2);
let (query, key, value) =
(
self.split_heads(&x[0], false),
self.split_heads(&x[1], true),
self.split_heads(&x[2], false)
);
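// When a cache is provided, prepend the cached keys/values. Cached keys are
// stored transposed (see `present` below), hence the transpose before concatenation.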
let (key, value) = match layer_past {
Some(past) => {
let key = Tensor::cat(&[past.get(0).transpose(-2, -1), key], -1);
let value = Tensor::cat(&[past.get(1), value], -2);
(key, value)
}
None => (key, value)
};
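// Cache the key (transposed back to (batch, head, seq, head_dim)) and value
// for incremental decoding.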
let present = Tensor::stack(&[key.transpose(-2, -1), value.copy()], 0);
let (a, attentions) = self.attention(&query, &key, &value, attention_mask, train);
let a = self.flatten(a).apply(&self.c_proj).apply_t(&self.resid_dropout, train);
(a, present, attentions)
}
}