use crate::common::dropout::Dropout;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::gpt2::transformer::Block;
use crate::pipelines::generation::{Cache, LMHeadModel};
use crate::Config;
use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
use tch::kind::Kind::Int64;
use tch::nn::embedding;
use tch::{nn, Tensor};
/// # GPT2 pretrained model weight resources
/// Namespace for `(local cache alias, remote URL)` constants pointing at weight files.
pub struct Gpt2ModelResources;
/// # GPT2 pretrained configuration resources
/// Namespace for `(local cache alias, remote URL)` constants pointing at config JSON files.
pub struct Gpt2ConfigResources;
/// # GPT2 pretrained vocabulary resources
/// Namespace for `(local cache alias, remote URL)` constants pointing at vocabulary files.
pub struct Gpt2VocabResources;
/// # GPT2 pretrained BPE merges resources
/// Namespace for `(local cache alias, remote URL)` constants pointing at merges files.
pub struct Gpt2MergesResources;
// Each constant is a `(local cache alias, download URL)` pair for a pretrained weight file.
impl Gpt2ModelResources {
    /// GPT2 base model weights.
    pub const GPT2: (&'static str, &'static str) = (
        "gpt2/model.ot",
        "https://cdn.huggingface.co/gpt2-rust_model.ot",
    );
    /// GPT2-medium model weights.
    pub const GPT2_MEDIUM: (&'static str, &'static str) = (
        "gpt2-medium/model.ot",
        "https://cdn.huggingface.co/gpt2-medium-rust_model.ot",
    );
    /// GPT2-large model weights.
    pub const GPT2_LARGE: (&'static str, &'static str) = (
        "gpt2-large/model.ot",
        "https://cdn.huggingface.co/gpt2-large-rust_model.ot",
    );
    /// GPT2-XL model weights.
    pub const GPT2_XL: (&'static str, &'static str) = (
        "gpt2-xl/model.ot",
        "https://cdn.huggingface.co/gpt2-xl-rust_model.ot",
    );
    /// DistilGPT2 model weights.
    pub const DISTIL_GPT2: (&'static str, &'static str) = (
        "distilgpt2/model.ot",
        "https://cdn.huggingface.co/distilgpt2-rust_model.ot",
    );
    /// DialoGPT-medium model weights.
    pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
        "dialogpt-medium/model.ot",
        "https://cdn.huggingface.co/microsoft/DialoGPT-medium/rust_model.ot",
    );
}
// Each constant is a `(local cache alias, download URL)` pair for a model configuration file.
impl Gpt2ConfigResources {
    /// GPT2 base model configuration.
    pub const GPT2: (&'static str, &'static str) = (
        "gpt2/config.json",
        "https://cdn.huggingface.co/gpt2-config.json",
    );
    /// GPT2-medium model configuration.
    pub const GPT2_MEDIUM: (&'static str, &'static str) = (
        "gpt2-medium/config.json",
        "https://cdn.huggingface.co/gpt2-medium-config.json",
    );
    /// GPT2-large model configuration.
    pub const GPT2_LARGE: (&'static str, &'static str) = (
        "gpt2-large/config.json",
        "https://cdn.huggingface.co/gpt2-large-config.json",
    );
    /// GPT2-XL model configuration.
    pub const GPT2_XL: (&'static str, &'static str) = (
        "gpt2-xl/config.json",
        "https://cdn.huggingface.co/gpt2-xl-config.json",
    );
    /// DistilGPT2 model configuration.
    pub const DISTIL_GPT2: (&'static str, &'static str) = (
        "distilgpt2/config.json",
        "https://cdn.huggingface.co/distilgpt2-config.json",
    );
    /// DialoGPT-medium model configuration.
    pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
        "dialogpt-medium/config.json",
        "https://cdn.huggingface.co/microsoft/DialoGPT-medium/config.json",
    );
}
// Each constant is a `(local cache alias, download URL)` pair for a vocabulary file.
// NOTE(review): local aliases use a `.txt` extension while every remote file is
// `vocab.json` — presumably only the alias string matters for caching; confirm.
impl Gpt2VocabResources {
    /// GPT2 base model vocabulary.
    pub const GPT2: (&'static str, &'static str) = (
        "gpt2/vocab.txt",
        "https://cdn.huggingface.co/gpt2-vocab.json",
    );
    /// GPT2-medium model vocabulary.
    pub const GPT2_MEDIUM: (&'static str, &'static str) = (
        "gpt2-medium/vocab.txt",
        "https://cdn.huggingface.co/gpt2-medium-vocab.json",
    );
    /// GPT2-large model vocabulary.
    pub const GPT2_LARGE: (&'static str, &'static str) = (
        "gpt2-large/vocab.txt",
        "https://cdn.huggingface.co/gpt2-large-vocab.json",
    );
    /// GPT2-XL model vocabulary.
    pub const GPT2_XL: (&'static str, &'static str) = (
        "gpt2-xl/vocab.txt",
        "https://cdn.huggingface.co/gpt2-xl-vocab.json",
    );
    /// DistilGPT2 model vocabulary.
    pub const DISTIL_GPT2: (&'static str, &'static str) = (
        "distilgpt2/vocab.txt",
        "https://cdn.huggingface.co/distilgpt2-vocab.json",
    );
    /// DialoGPT-medium model vocabulary.
    pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
        "dialogpt-medium/vocab.txt",
        "https://cdn.huggingface.co/microsoft/DialoGPT-medium/vocab.json",
    );
}
// Each constant is a `(local cache alias, download URL)` pair for a BPE merges file.
impl Gpt2MergesResources {
    /// GPT2 base model merges.
    pub const GPT2: (&'static str, &'static str) = (
        "gpt2/merges.txt",
        "https://cdn.huggingface.co/gpt2-merges.txt",
    );
    /// GPT2-medium model merges.
    pub const GPT2_MEDIUM: (&'static str, &'static str) = (
        "gpt2-medium/merges.txt",
        "https://cdn.huggingface.co/gpt2-medium-merges.txt",
    );
    /// GPT2-large model merges.
    pub const GPT2_LARGE: (&'static str, &'static str) = (
        "gpt2-large/merges.txt",
        "https://cdn.huggingface.co/gpt2-large-merges.txt",
    );
    /// GPT2-XL model merges.
    pub const GPT2_XL: (&'static str, &'static str) = (
        "gpt2-xl/merges.txt",
        "https://cdn.huggingface.co/gpt2-xl-merges.txt",
    );
    /// DistilGPT2 model merges.
    pub const DISTIL_GPT2: (&'static str, &'static str) = (
        "distilgpt2/merges.txt",
        "https://cdn.huggingface.co/distilgpt2-merges.txt",
    );
    /// DialoGPT-medium model merges.
    pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
        "dialogpt-medium/merges.txt",
        "https://cdn.huggingface.co/microsoft/DialoGPT-medium/merges.txt",
    );
}
/// Activation function variants for the GPT family.
/// The lowercase variant names (hence the lint allow) must match the `afn`
/// string values found in the serialized configuration files, since the enum
/// is deserialized directly by `serde`.
#[allow(non_camel_case_types)]
#[derive(Debug, Serialize, Deserialize)]
pub enum GptActivation {
    gelu,
    relu,
    swish,
}
/// # GPT2 model configuration
/// Mirrors the JSON configuration shipped with pretrained checkpoints
/// (deserialized via `serde`). `Option` fields may be absent from older
/// configuration files; defaults are applied where the field is consumed.
#[derive(Debug, Serialize, Deserialize)]
pub struct Gpt2Config {
    // Attention dropout probability — presumably consumed inside `Block`; not read in this file.
    pub attn_pdrop: Option<f64>,
    // Embedding dropout probability; defaults to 0.1 in `Gpt2Model::new` when absent.
    pub embd_pdrop: Option<f64>,
    // Hidden-layer dropout probability — not read in this file; TODO confirm consumer.
    pub hidden_dropout_prob: Option<f64>,
    // Activation function selection (`gelu`/`relu`/`swish`).
    pub afn: Option<GptActivation>,
    // Weight initialization range — not read in this file (weights are loaded pretrained).
    pub initializer_range: f64,
    // Epsilon used for the `ln_f` layer normalization (and presumably in-block norms).
    pub layer_norm_epsilon: f64,
    // Context window size — not read in this file; TODO confirm consumer.
    pub n_ctx: i64,
    // Hidden/embedding dimension; sizes both embedding tables and the LM head.
    pub n_embd: i64,
    // Number of attention heads — presumably consumed inside `Block`.
    pub n_head: i64,
    // Number of transformer `Block`s stacked in `Gpt2Model::h`.
    pub n_layer: i64,
    // Size of the position-embedding table (`wpe`).
    pub n_positions: i64,
    // Label count for classification heads — unused by the models in this file.
    pub num_labels: Option<i64>,
    // Return per-layer key/value caches from `forward_t`; defaults to `true`.
    pub output_past: Option<bool>,
    // Collect per-layer attention weights in `forward_t`; defaults to `false`.
    pub output_attentions: Option<bool>,
    // Collect per-layer hidden states in `forward_t`; defaults to `false`.
    pub output_hidden_states: Option<bool>,
    // Residual dropout probability — presumably consumed inside `Block`; not read here.
    pub resid_pdrop: Option<f64>,
    // Vocabulary size; sizes the token embedding table (`wte`) and the LM head output.
    pub vocab_size: i64,
}
// Hooks `Gpt2Config` into the crate-wide `Config` trait; loading behavior
// (presumably JSON file parsing) comes from the trait's default methods — see crate root.
impl Config<Gpt2Config> for Gpt2Config {}
/// # GPT2 base model
/// Token and position embeddings, a stack of transformer `Block`s and a final
/// layer normalization. Optional outputs are controlled by the boolean flags.
pub struct Gpt2Model {
    // Token (word) embeddings; also reused for `token_type_ids` lookups in `forward_t`.
    wte: nn::Embedding,
    // Position embeddings.
    wpe: nn::Embedding,
    // Dropout applied to the summed embedding tensor.
    drop: Dropout,
    // Final layer normalization applied to the last hidden state.
    ln_f: nn::LayerNorm,
    // Transformer blocks (`config.n_layer` of them).
    h: Vec<Block>,
    // When true, `forward_t` returns the per-layer cached tensors ("presents").
    output_past: bool,
    // When true, `forward_t` collects the hidden state entering each block.
    output_hidden_states: bool,
    // When true, `forward_t` collects each block's attention weights.
    output_attentions: bool,
}
impl Gpt2Model {
    /// Builds a new `Gpt2Model` under variable path `p / "transformer"`.
    ///
    /// # Arguments
    /// * `p` - variable store path root for the model weights
    /// * `config` - model hyper-parameters (see `Gpt2Config`)
    pub fn new(p: &nn::Path, config: &Gpt2Config) -> Gpt2Model {
        let p = &(p / "transformer");
        let wte = embedding(
            &(p / "wte"),
            config.vocab_size,
            config.n_embd,
            Default::default(),
        );
        let wpe = embedding(
            &(p / "wpe"),
            config.n_positions,
            config.n_embd,
            Default::default(),
        );
        // Embedding dropout defaults to 0.1 when absent from the configuration.
        let embd_pdrop = config.embd_pdrop.unwrap_or(0.1);
        let drop = Dropout::new(embd_pdrop);
        let layer_norm_config = nn::LayerNormConfig {
            eps: config.layer_norm_epsilon,
            ..Default::default()
        };
        let ln_f = nn::layer_norm(p / "ln_f", vec![config.n_embd], layer_norm_config);
        let h_path = &(p / "h");
        // One transformer block per layer, stored under "h/<index>".
        // NOTE(review): the `true` flag forwarded to `Block::new` is opaque here — confirm
        // its meaning (presumably attention scaling) against the `Block` definition.
        let h: Vec<Block> = (0..config.n_layer)
            .map(|layer_index| Block::new(&(h_path / layer_index), config, true))
            .collect();
        let output_attentions = config.output_attentions.unwrap_or(false);
        let output_past = config.output_past.unwrap_or(true);
        let output_hidden_states = config.output_hidden_states.unwrap_or(false);
        Gpt2Model {
            wte,
            wpe,
            drop,
            ln_f,
            h,
            output_past,
            output_hidden_states,
            output_attentions,
        }
    }

    /// Forward pass through the model.
    ///
    /// # Arguments
    /// * `input_ids` - token indices; mutually exclusive with `input_embeds`
    /// * `layer_past` - optional per-layer cached activations; must have one entry per block
    /// * `attention_mask` - optional {0, 1} mask converted to an additive pre-softmax mask
    /// * `token_type_ids` - optional segment indices, embedded through `wte`
    /// * `position_ids` - optional explicit positions; otherwise derived from cache length
    /// * `input_embeds` - precomputed input embeddings; mutually exclusive with `input_ids`
    /// * `train` - enables dropout when true
    ///
    /// # Returns
    /// `(last hidden state, optional presents, optional hidden states, optional attentions)`,
    /// the optional vectors being populated according to the `output_*` flags.
    ///
    /// # Errors
    /// Returns `Err` if both or neither of `input_ids` / `input_embeds` are provided.
    pub fn forward_t(
        &self,
        input_ids: &Option<Tensor>,
        layer_past: &Option<Vec<Tensor>>,
        attention_mask: &Option<Tensor>,
        token_type_ids: &Option<Tensor>,
        position_ids: &Option<Tensor>,
        input_embeds: &Option<Tensor>,
        train: bool,
    ) -> Result<
        (
            Tensor,
            Option<Vec<Tensor>>,
            Option<Vec<Tensor>>,
            Option<Vec<Tensor>>,
        ),
        &'static str,
    > {
        // Exactly one of `input_ids` / `input_embeds` must be provided.
        let (input_embeddings, seq_length) = match (input_ids, input_embeds) {
            (Some(_), Some(_)) => {
                return Err("Only one of input ids or input embeddings may be set");
            }
            (Some(input_value), None) => (
                input_value.apply(&self.wte),
                *input_value.size().last().unwrap(),
            ),
            (None, Some(embeds)) => (embeds.copy(), embeds.size()[1]),
            (None, None) => {
                return Err("At least one of input ids or input embeddings must be set");
            }
        };
        let (layer_past, layer_past_length) = match layer_past {
            Some(value) => {
                assert_eq!(
                    value.len(),
                    self.h.len(),
                    "Past activations vector must be of length equal to the number of layers"
                );
                (
                    value
                        .iter()
                        .map(|v| Some(v.copy()))
                        .collect::<Vec<Option<Tensor>>>(),
                    // NOTE(review): cache length read from dim 3 — assumes a fixed past
                    // tensor layout produced by `Block`; confirm against its definition.
                    value[0].size()[3],
                )
            }
            // No cache: one empty slot per layer (`Tensor` is not `Clone`, so no `vec!`).
            None => ((0..self.h.len()).map(|_| None).collect(), 0),
        };
        // Positions continue from the cached sequence length when a past is supplied.
        let position_ids = match position_ids {
            Some(value) => value.copy(),
            None => Tensor::arange1(
                layer_past_length,
                seq_length + layer_past_length,
                (Int64, input_embeddings.device()),
            )
            .unsqueeze(0),
        };
        // Convert the {0, 1} mask to additive form: 0 for kept positions,
        // -10000 for masked positions, broadcastable over heads and query length.
        let attention_mask: Option<Tensor> = attention_mask.as_ref().map(|value| {
            (value
                .view((input_embeddings.size()[0], -1))
                .unsqueeze(1)
                .unsqueeze(2)
                - 1.0)
                * 10000.0
        });
        let position_embeds = position_ids.apply(&self.wpe);
        // Token type embeddings reuse the word embedding table; zeros when absent.
        let token_type_embeds = match token_type_ids {
            Some(value) => value.apply(&self.wte),
            None => Tensor::zeros_like(&position_embeds),
        };
        let mut hidden_state: Tensor =
            (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
        let mut all_presents: Option<Vec<Tensor>> =
            if self.output_past { Some(vec![]) } else { None };
        let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
            Some(vec![])
        } else {
            None
        };
        let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
            Some(vec![])
        } else {
            None
        };
        for (layer, past) in self.h.iter().zip(layer_past) {
            // Hidden states are recorded *before* each block, matching the
            // original collection order.
            if let Some(hidden_states) = all_hidden_states.as_mut() {
                hidden_states.push(hidden_state.copy());
            }
            let temp = layer.forward_t(&hidden_state, &past, &attention_mask, train);
            hidden_state = temp.0;
            if let Some(presents) = all_presents.as_mut() {
                presents.push(temp.1.copy());
            }
            if let Some(attentions) = all_attentions.as_mut() {
                // NOTE(review): assumes blocks always return attention weights;
                // confirm `Block::forward_t` honors this when attentions are requested.
                attentions.push(temp.2.as_ref().unwrap().copy());
            }
        }
        Ok((
            hidden_state.apply(&self.ln_f),
            all_presents,
            all_hidden_states,
            all_attentions,
        ))
    }
}
/// # GPT2 language model head
/// Wraps the base `Gpt2Model` and adds a bias-free linear projection from the
/// hidden dimension to the vocabulary, producing next-token logits.
pub struct GPT2LMHeadModel {
    // Base transformer producing the last hidden state.
    transformer: Gpt2Model,
    // Linear projection (no bias) from `n_embd` to `vocab_size`.
    lm_head: LinearNoBias,
}
impl GPT2LMHeadModel {
    /// Builds a new `GPT2LMHeadModel`: a `Gpt2Model` backbone plus a bias-free
    /// linear head projecting `n_embd` to `vocab_size` under `p / "lm_head"`.
    ///
    /// # Arguments
    /// * `p` - variable store path root for the model weights
    /// * `config` - model hyper-parameters (see `Gpt2Config`)
    pub fn new(p: &nn::Path, config: &Gpt2Config) -> GPT2LMHeadModel {
        // `p` is already `&nn::Path`; the original `&p` passed `&&nn::Path`
        // and relied on deref coercion.
        let transformer = Gpt2Model::new(p, config);
        let lm_head = linear_no_bias(
            &(p / "lm_head"),
            config.n_embd,
            config.vocab_size,
            Default::default(),
        );
        GPT2LMHeadModel {
            transformer,
            lm_head,
        }
    }
}
impl LMHeadModel for GPT2LMHeadModel {
    /// Forward pass through the transformer followed by the LM head.
    ///
    /// # Arguments
    /// * `input_ids` / `input_embeds` - mutually exclusive inputs, forwarded to the transformer
    /// * `layer_past` - generation cache; only `Cache::GPT2Cache` or `Cache::None` are accepted
    /// * `attention_mask`, `token_type_ids`, `position_ids` - forwarded to the transformer
    /// * `_encoder_outputs`, `_decoder_input_ids` - unused (GPT2 is decoder-only)
    /// * `train` - enables dropout when true
    ///
    /// # Returns
    /// `(lm_logits, None, updated GPT2 cache, optional hidden states, optional attentions)`.
    ///
    /// # Errors
    /// Returns `Err` for an incompatible cache variant or invalid input combination.
    fn forward_t(
        &self,
        input_ids: &Option<Tensor>,
        layer_past: Cache,
        attention_mask: &Option<Tensor>,
        token_type_ids: &Option<Tensor>,
        position_ids: &Option<Tensor>,
        input_embeds: &Option<Tensor>,
        _encoder_outputs: Option<&Tensor>,
        _decoder_input_ids: &Option<Tensor>,
        train: bool,
    ) -> Result<
        (
            Tensor,
            Option<Tensor>,
            Cache,
            Option<Vec<Tensor>>,
            Option<Vec<Tensor>>,
        ),
        &'static str,
    > {
        // Each arm yields the transformer `Result` directly; the original
        // `?`-unwrapped it, re-wrapped in `Ok`, then applied `?` again.
        let (output, past, all_hidden_states, all_attentions) = match layer_past {
            Cache::GPT2Cache(layer_past) => self.transformer.forward_t(
                input_ids,
                &layer_past,
                attention_mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            ),
            Cache::None => self.transformer.forward_t(
                input_ids,
                &None,
                attention_mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            ),
            _ => Err("Cache not compatible with GPT2 model"),
        }?;
        let lm_logits = output.apply(&self.lm_head);
        // Second tuple slot (encoder hidden state) is always `None` for a decoder-only model.
        Ok((
            lm_logits,
            None,
            Cache::GPT2Cache(past),
            all_hidden_states,
            all_attentions,
        ))
    }
}