use crate::gpt2::attention::Attention;
use crate::gpt2::transformer::MLP;
use crate::gpt2::Gpt2Config;
use tch::{nn, Tensor};
/// A single transformer layer: self-attention followed by a feed-forward
/// sub-layer (`MLP`), each wrapped in a residual connection.
///
/// Layer normalisation is applied *after* each residual addition
/// ("post-norm"): `ln_1` normalises `x + attn(x)` and `ln_2` normalises
/// `h + mlp(h)`.
pub struct Block {
    /// Layer norm applied to the attention residual sum.
    ln_1: nn::LayerNorm,
    /// Self-attention sub-layer.
    attn: Attention,
    /// Layer norm applied to the feed-forward residual sum.
    ln_2: nn::LayerNorm,
    /// Feed-forward sub-layer.
    mlp: MLP,
}

impl Block {
    /// Builds a block whose variables are registered under `p`.
    ///
    /// Sub-modules live at `p / "ln_1"`, `p / "ln_2"`, `p / "attn"` and
    /// `p / "mlp"`. Both layer norms normalise over `vec![config.n_embd]` with
    /// `config.layer_norm_epsilon` as epsilon; every other layer-norm setting is
    /// the `nn::LayerNormConfig` default.
    ///
    /// `scale` is forwarded unchanged to `Attention::new` (presumably it toggles
    /// scaling of the attention scores; confirm in `Attention`).
    pub fn new(p: &nn::Path, config: &Gpt2Config, scale: bool) -> Block {
        let norm_config = nn::LayerNormConfig {
            eps: config.layer_norm_epsilon,
            ..Default::default()
        };

        // The two norms differ only in their variable-store sub-path.
        let build_norm =
            |name: &str| nn::layer_norm(p / name, vec![config.n_embd], norm_config);

        // Keep the construction sequence ln_1, ln_2, attn, mlp so variables are
        // created in the same order as before.
        let ln_1 = build_norm("ln_1");
        let ln_2 = build_norm("ln_2");
        let attn = Attention::new(&(p / "attn"), config, scale);
        let mlp = MLP::new(&(p / "mlp"), config);

        Block { ln_1, attn, ln_2, mlp }
    }

    /// Runs the block on `x` and returns `(hidden_states, attention_weights)`.
    ///
    /// Computes `h = ln_1(x + attn(x))`, then returns `ln_2(h + mlp(h))`.
    /// `attention_mask` and `train` are forwarded to the attention sub-layer,
    /// and `train` is also forwarded to the MLP (presumably it enables dropout;
    /// confirm in those modules).
    ///
    /// `&None` is passed as attention's second argument, presumably meaning
    /// no cached past state; confirm in `Attention::forward_t`. The second
    /// value attention returns is discarded.
    ///
    /// The second tuple element is whatever attention returns as its third
    /// value. It may be `None`; see `Attention::forward_t` for when.
    pub fn forward_t(
        &self,
        x: &Tensor,
        attention_mask: &Option<Tensor>,
        train: bool,
    ) -> (Tensor, Option<Tensor>) {
        let (attn_output, _, attention_weights) =
            self.attn.forward_t(x, &None, attention_mask, train);
        let hidden = (x + attn_output).apply(&self.ln_1);

        let mlp_output = self.mlp.forward_t(&hidden, train);
        let hidden = (hidden + mlp_output).apply(&self.ln_2);

        (hidden, attention_weights)
    }
}