use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::gpt2::transformer::Block;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Int64;
use tch::nn::embedding;
use tch::{nn, Device, Kind, Tensor};
/// # GPT-2 Pretrained model weight files
/// `(cache subdirectory, download URL)` pairs for weights hosted on the Hugging Face model hub.
pub struct Gpt2ModelResources;
/// # GPT-2 Pretrained model config files
pub struct Gpt2ConfigResources;
/// # GPT-2 Pretrained model vocab files
pub struct Gpt2VocabResources;
/// # GPT-2 Pretrained model merges files
pub struct Gpt2MergesResources;
impl Gpt2ModelResources {
pub const GPT2: (&'static str, &'static str) = (
"gpt2/model",
"https://huggingface.co/gpt2/resolve/main/rust_model.ot",
);
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/model",
"https://huggingface.co/gpt2-medium/resolve/main/rust_model.ot",
);
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/model",
"https://huggingface.co/gpt2-large/resolve/main/rust_model.ot",
);
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/model",
"https://huggingface.co/gpt2-xl/resolve/main/rust_model.ot",
);
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/model",
"https://huggingface.co/distilgpt2/resolve/main/rust_model.ot",
);
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/model",
"https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/rust_model.ot",
);
}
impl Gpt2ConfigResources {
pub const GPT2: (&'static str, &'static str) = (
"gpt2/config",
"https://huggingface.co/gpt2/resolve/main/config.json",
);
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/config",
"https://huggingface.co/gpt2-medium/resolve/main/config.json",
);
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/config",
"https://huggingface.co/gpt2-large/resolve/main/config.json",
);
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/config",
"https://huggingface.co/gpt2-xl/resolve/main/config.json",
);
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/config",
"https://huggingface.co/distilgpt2/resolve/main/config.json",
);
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/config",
"https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/config.json",
);
}
impl Gpt2VocabResources {
pub const GPT2: (&'static str, &'static str) = (
"gpt2/vocab",
"https://huggingface.co/gpt2/resolve/main/vocab.json",
);
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/vocab",
"https://huggingface.co/gpt2-medium/resolve/main/vocab.json",
);
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/vocab",
"https://huggingface.co/gpt2-large/resolve/main/vocab.json",
);
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/vocab",
"https://huggingface.co/gpt2-xl/resolve/main/vocab.json",
);
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/vocab",
"https://huggingface.co/distilgpt2/resolve/main/vocab.json",
);
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/vocab",
"https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/vocab.json",
);
}
impl Gpt2MergesResources {
pub const GPT2: (&'static str, &'static str) = (
"gpt2/merges",
"https://huggingface.co/gpt2/resolve/main/merges.txt",
);
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/merges",
"https://huggingface.co/gpt2-medium/resolve/main/merges.txt",
);
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/merges",
"https://huggingface.co/gpt2-large/resolve/main/merges.txt",
);
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/merges",
"https://huggingface.co/gpt2-xl/resolve/main/merges.txt",
);
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/merges",
"https://huggingface.co/distilgpt2/resolve/main/merges.txt",
);
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/merges",
"https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/merges.txt",
);
}
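/// # GPT-2 model configuration
/// Defines the GPT-2 model architecture (e.g. number of layers, hidden size,
/// vocab size), mirroring the fields of a Hugging Face `config.json`.
/// Optional fields (dropout probabilities, output flags, `num_labels`) may be
/// absent from the JSON file and fall back to defaults at model construction.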
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Gpt2Config {
pub attn_pdrop: Option<f64>,
pub embd_pdrop: Option<f64>,
pub hidden_dropout_prob: Option<f64>,
pub afn: Option<Activation>,
pub initializer_range: f64,
pub layer_norm_epsilon: f64,
pub n_ctx: i64,
pub n_embd: i64,
pub n_head: i64,
pub n_layer: i64,
pub n_positions: i64,
pub num_labels: Option<i64>,
pub output_past: Option<bool>,
pub output_attentions: Option<bool>,
pub output_hidden_states: Option<bool>,
pub resid_pdrop: Option<f64>,
pub vocab_size: i64,
pub decoder_start_token_id: Option<i64>,
pub forced_bos_token_id: Option<i64>,
pub forced_eos_token_id: Option<i64>,
}
impl Config for Gpt2Config {}
impl Default for Gpt2Config {
fn default() -> Self {
Gpt2Config {
attn_pdrop: Some(0.1),
embd_pdrop: Some(0.1),
hidden_dropout_prob: None,
afn: Some(Activation::gelu_new),
initializer_range: 0.02,
layer_norm_epsilon: 1e-5,
n_ctx: 1024,
n_embd: 768,
n_head: 12,
n_layer: 12,
n_positions: 1024,
num_labels: None,
output_past: None,
output_attentions: None,
output_hidden_states: None,
resid_pdrop: Some(0.1),
vocab_size: 50257,
decoder_start_token_id: None,
forced_bos_token_id: None,
forced_eos_token_id: None,
}
}
}
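/// # GPT-2 Base model
/// Token (`wte`) and position (`wpe`) embeddings, embedding dropout, a stack
/// of `n_layer` transformer blocks and a final layer norm. The forward pass
/// returns the last hidden state together with optional cached keys/values,
/// intermediate hidden states and attention weights.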
pub struct Gpt2Model {
wte: nn::Embedding,
wpe: nn::Embedding,
drop: Dropout,
ln_f: nn::LayerNorm,
h: Vec<Block>,
output_past: bool,
output_hidden_states: bool,
output_attentions: bool,
}
impl Gpt2Model {
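/// Builds a new `Gpt2Model` under the `"transformer"` sub-path of the
/// supplied variable store path, matching the variable naming of converted
/// Hugging Face checkpoints.
///
/// A minimal construction sketch (paths are placeholders; weights still need
/// to be loaded into the var store separately):
///
/// ```no_run
/// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let p = nn::VarStore::new(Device::Cpu);
/// let config = Gpt2Config::from_file(config_path);
/// let gpt2 = Gpt2Model::new(&p.root() / "gpt2", &config);
/// ```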
pub fn new<'p, P>(p: P, config: &Gpt2Config) -> Gpt2Model
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow() / "transformer";
let wte = embedding(
&p / "wte",
config.vocab_size,
config.n_embd,
Default::default(),
);
let wpe = embedding(
&p / "wpe",
config.n_positions,
config.n_embd,
Default::default(),
);
let embd_pdrop = config.embd_pdrop.unwrap_or(0.1);
let drop = Dropout::new(embd_pdrop);
let layer_norm_config = nn::LayerNormConfig {
eps: config.layer_norm_epsilon,
..Default::default()
};
let ln_f = nn::layer_norm(&p / "ln_f", vec![config.n_embd], layer_norm_config);
let mut h: Vec<Block> = vec![];
let h_path = &p / "h";
for layer_index in 0..config.n_layer {
h.push(Block::new(&h_path / layer_index, config, true));
}
let output_attentions = config.output_attentions.unwrap_or(false);
let output_past = config.output_past.unwrap_or(true);
let output_hidden_states = config.output_hidden_states.unwrap_or(false);
Gpt2Model {
wte,
wpe,
drop,
ln_f,
h,
output_past,
output_hidden_states,
output_attentions,
}
}
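/// Forward pass through the model.
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence length*). Must be provided when no `input_embeds` are given.
/// * `layer_past` - Optional cached keys and values, one tensor per layer, from previous decoding steps.
/// * `attention_mask` - Optional mask of shape (*batch size*, *sequence length*), 1 for tokens to attend to and 0 for padding.
/// * `token_type_ids` - Optional segment ids, embedded through the token embedding matrix.
/// * `position_ids` - Optional position indices; defaults to a contiguous range starting at the cached sequence length.
/// * `input_embeds` - Optional pre-computed input embeddings, mutually exclusive with `input_ids`.
/// * `train` - Boolean flag turning dropout on during training.
///
/// A usage sketch with a dummy input (shapes are illustrative):
///
/// ```no_run
/// # use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::{nn, Device, Kind, Tensor};
/// # let p = nn::VarStore::new(Device::Cpu);
/// # let config = Gpt2Config::from_file(Path::new("path/to/config.json"));
/// # let gpt2 = Gpt2Model::new(p.root(), &config);
/// let input_ids = Tensor::zeros(&[1, 8], (Kind::Int64, Device::Cpu));
/// let model_output = gpt2
///     .forward_t(Some(&input_ids), None, None, None, None, None, false)
///     .unwrap();
/// ```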
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
layer_past: Option<&Vec<Tensor>>,
attention_mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> Result<Gpt2ModelOutput, RustBertError> {
let (calc_input_embeddings, input_size, _) =
process_ids_embeddings_pair(input_ids, input_embeds, &self.wte)?;
let input_embeddings =
input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
let seq_length = input_size[1];
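// Re-use cached keys/values when provided: the cached sequence length
// offsets the default position ids for the newly provided tokens. Without a
// cache, an empty per-layer placeholder is created instead.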
let (layer_past, layer_past_length) = match layer_past {
Some(value) => {
assert_eq!(
value.len(),
self.h.len(),
"Past activations vector must be of length equal to the number of layers"
);
(
value
.iter()
.map(|v| Some(v.copy()))
.collect::<Vec<Option<Tensor>>>(),
value[0].size()[3],
)
}
None => {
let mut out = Vec::with_capacity(self.h.len());
out.resize_with(self.h.len(), || None::<Tensor>);
(out, 0)
}
};
let position_ids = match position_ids {
Some(value) => value.copy(),
None => Tensor::arange_start(
layer_past_length,
seq_length + layer_past_length,
(Int64, input_embeddings.device()),
)
.unsqueeze(0),
};
// Convert the (batch size, sequence length) padding mask into an additive
// bias broadcastable over (batch, heads, query, key): attended positions
// contribute 0.0, masked positions a large negative value.
let attention_mask: Option<Tensor> = attention_mask.map(|value| {
let attention_mask = value
.view((input_embeddings.size()[0], -1))
.unsqueeze(1)
.unsqueeze(2)
.to_kind(input_embeddings.kind());
(1.0 - attention_mask) * (-10000.0)
});
let position_embeds = position_ids.apply(&self.wpe);
let token_type_embeds = match token_type_ids {
Some(value) => value.apply(&self.wte),
None => Tensor::zeros_like(&position_embeds),
};
let mut hidden_state: Tensor =
(input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
let mut all_presents: Option<Vec<Tensor>> =
if self.output_past { Some(vec![]) } else { None };
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
};
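// Thread the hidden state through the block stack, collecting presents,
// attention weights and intermediate hidden states when the corresponding
// output flags are enabled.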
let layer_iter = self.h.iter().zip(layer_past);
for layer_values in layer_iter {
let (layer, past) = layer_values;
let temp =
layer.forward_t(&hidden_state, past.as_ref(), attention_mask.as_ref(), train);
hidden_state = temp.0;
if let Some(presents) = all_presents.borrow_mut() {
presents.push(temp.1);
};
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(temp.2.unwrap());
};
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
}
Ok(Gpt2ModelOutput {
output: hidden_state.apply(&self.ln_f),
cache: all_presents,
all_hidden_states,
all_attentions,
})
}
}
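/// # GPT-2 Language Modeling head
/// Wraps the GPT-2 base model with a language modeling head whose weights are
/// tied to the input embedding matrix (`wte`); logits are obtained by a
/// linear projection against the transposed embeddings, with no separate
/// output layer.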
pub struct GPT2LMHeadModel {
transformer: Gpt2Model,
}
impl GPT2LMHeadModel {
pub fn new<'p, P>(p: P, config: &Gpt2Config) -> GPT2LMHeadModel
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let transformer = Gpt2Model::new(p, config);
GPT2LMHeadModel { transformer }
}
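/// Forward pass through the transformer and the tied language modeling head,
/// returning next-token logits of shape
/// (*batch size*, *sequence length*, *vocab size*) along with the updated
/// key/value cache wrapped in `Cache::GPT2Cache`.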
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
layer_past: Option<&Vec<Tensor>>,
attention_mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> Result<LMModelOutput, RustBertError> {
let base_model_output = self.transformer.forward_t(
input_ids,
layer_past,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train,
)?;
let lm_logits = base_model_output
.output
.linear::<Tensor>(&self.transformer.wte.ws, None);
Ok(LMModelOutput {
lm_logits,
cache: Cache::GPT2Cache(base_model_output.cache),
})
}
}
/// Container for the output of the GPT-2 base model.
pub struct Gpt2ModelOutput {
/// Hidden state of the last transformer block, after the final layer norm
pub output: Tensor,
/// Cached keys and values (one tensor per layer) for incremental decoding
pub cache: Option<Vec<Tensor>>,
/// Hidden states of all intermediate layers, if enabled in the config
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights of all intermediate layers, if enabled in the config
pub all_attentions: Option<Vec<Tensor>>,
}
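/// # Language generation model based on the GPT-2 architecture
/// Bundles a `GPT2LMHeadModel`, its tokenizer, the variable store holding the
/// weights and the generation configuration, and implements the shared
/// `LanguageGenerator` interface on top of them.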
pub struct GPT2Generator {
model: GPT2LMHeadModel,
tokenizer: TokenizerOption,
var_store: nn::VarStore,
generate_config: GenerateConfig,
bos_token_id: Option<i64>,
eos_token_ids: Option<Vec<i64>>,
pad_token_id: Option<i64>,
is_encoder_decoder: bool,
vocab_size: i64,
decoder_start_id: Option<i64>,
max_position_embeddings: i64,
}
impl GPT2Generator {
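/// Builds a new `GPT2Generator` from a `GenerateConfig`, resolving the
/// configured vocab and merges resources and constructing the GPT-2 tokenizer
/// from them. Fails with an `InvalidConfigurationError` when no merges
/// resource is provided.
///
/// A usage sketch; `GenerateConfig` field names and the return type of
/// `generate` may differ slightly across crate versions:
///
/// ```no_run
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// use rust_bert::gpt2::GPT2Generator;
/// use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
///
/// let generate_config = GenerateConfig {
///     max_length: Some(30),
///     do_sample: true,
///     num_beams: 5,
///     temperature: 1.1,
///     num_return_sequences: 3,
///     ..Default::default()
/// };
/// let gpt2_generator = GPT2Generator::new(generate_config)?;
/// let output = gpt2_generator.generate(Some(&["The dog"]), None);
/// # Ok(())
/// # }
/// ```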
pub fn new(generate_config: GenerateConfig) -> Result<GPT2Generator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"GPT2 expects a merges resources to be provided".to_string(),
)
})?
.get_local_path()?;
let tokenizer = TokenizerOption::from_file(
ModelType::GPT2,
vocab_path.to_str().unwrap(),
Some(merges_path.to_str().unwrap()),
false,
None,
None,
)?;
Self::new_with_tokenizer(generate_config, tokenizer)
}
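/// Builds a new `GPT2Generator` reusing a caller-supplied tokenizer. The
/// model configuration and weights are loaded from the resources in
/// `generate_config`, and the special token ids (BOS, EOS, padding) are taken
/// from the tokenizer.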
pub fn new_with_tokenizer(
generate_config: GenerateConfig,
tokenizer: TokenizerOption,
) -> Result<GPT2Generator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let config = Gpt2Config::from_file(config_path);
let model = GPT2LMHeadModel::new(var_store.root(), &config);
crate::resources::load_weights(
&generate_config.model_resource,
&mut var_store,
generate_config.kind,
device,
)?;
let bos_token_id = tokenizer.get_bos_id();
let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
let pad_token_id = tokenizer.get_pad_id();
let max_position_embeddings = config.n_positions;
let is_encoder_decoder = false;
let vocab_size = config.vocab_size;
let decoder_start_id = config.decoder_start_token_id;
Ok(GPT2Generator {
model,
tokenizer,
var_store,
generate_config,
bos_token_id,
eos_token_ids,
pad_token_id,
is_encoder_decoder,
vocab_size,
decoder_start_id,
max_position_embeddings,
})
}
}
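// Wires GPT2Generator into the shared generation loop: the accessors expose
// the tokenizer, special token ids and model geometry, while forward_t,
// prepare_inputs_for_generation and reorder_cache supply the GPT-2 specific
// behaviour.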
impl PrivateLanguageGenerator for GPT2Generator {
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn get_device(&self) -> Device {
self.var_store.device()
}
fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
Ok(&mut self.var_store)
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
}
fn get_bos_id(&self) -> Option<i64> {
self.bos_token_id
}
fn get_eos_ids(&self) -> Option<&Vec<i64>> {
self.eos_token_ids.as_ref()
}
fn get_pad_id(&self) -> Option<i64> {
self.pad_token_id
}
fn is_encoder_decoder(&self) -> bool {
self.is_encoder_decoder
}
fn get_vocab_size(&self) -> i64 {
self.vocab_size
}
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> Option<i64> {
Some(self.max_position_embeddings)
}
fn forward_t(
&self,
input_ids: Option<&Tensor>,
layer_past: Cache,
attention_mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: Option<&Tensor>,
train: bool,
) -> Result<LMModelOutput, RustBertError> {
match layer_past {
Cache::GPT2Cache(layer_past) => self.model.forward_t(
input_ids,
layer_past.as_ref(),
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
Cache::None => self.model.forward_t(
input_ids,
None,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
_ => Err(RustBertError::ValueError(
"Cache not compatible with GPT2 Model".into(),
)),
}
}
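// When a key/value cache is present, only the last generated token and its
// position id are fed to the model; all earlier positions are already
// represented in the cache. Position ids are derived from the cumulative
// attention mask so that padding does not shift token positions.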
fn prepare_inputs_for_generation<'a>(
&self,
input_ids: Tensor,
_encoder_outputs: Option<&'a Tensor>,
past: Cache,
attention_mask: Tensor,
) -> PreparedInput<'a> {
let position_ids = (attention_mask.totype(Kind::Int64).cumsum(-1, Kind::Int64) - 1)
.masked_fill(&attention_mask.eq(0), 1);
match past {
Cache::GPT2Cache(past) => {
if past.is_some() {
PreparedInput {
prepared_input: Some(input_ids.select(1, -1).unsqueeze(-1)),
prepared_attention_mask: Some(attention_mask),
prepared_encoder_output: None,
prepared_decoder_input: None,
prepared_position_ids: Some(position_ids.select(1, -1).unsqueeze(-1)),
prepared_past: Cache::GPT2Cache(past),
}
} else {
PreparedInput {
prepared_input: Some(input_ids),
prepared_attention_mask: Some(attention_mask),
prepared_encoder_output: None,
prepared_decoder_input: None,
prepared_position_ids: Some(position_ids),
prepared_past: Cache::GPT2Cache(None),
}
}
}
Cache::None => PreparedInput {
prepared_input: Some(input_ids),
prepared_attention_mask: Some(attention_mask),
prepared_encoder_output: None,
prepared_decoder_input: None,
prepared_position_ids: Some(position_ids),
prepared_past: Cache::GPT2Cache(None),
},
_ => panic!("Cache type incompatible with GPT2"),
}
}
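// During beam search, cached keys and values must follow their beams when
// hypotheses are re-ranked: each layer's stacked key/value tensor is
// re-indexed along its batch dimension (dim 1) with `beam_indices`.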
fn reorder_cache(
&self,
past: &mut Cache,
_encoder_outputs: Option<Tensor>,
beam_indices: &Tensor,
) -> Option<Tensor> {
match past {
Cache::GPT2Cache(cached_decoder_state) => match cached_decoder_state {
Some(value) => {
for layer_past in value.iter_mut() {
*layer_past = layer_past.index_select(1, beam_indices);
}
None
}
None => None,
},
Cache::None => None,
_ => {
panic!("Invalid cache for GPT2 model");
}
}
}
}
impl LanguageGenerator for GPT2Generator {}