use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{
Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use crate::openai_gpt::transformer::Block;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation_utils::{
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
};
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::OpenAiGptTokenizer;
use rust_tokenizers::vocab::OpenAiGptVocab;
use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Int64;
use tch::nn::embedding;
use tch::{nn, Tensor};
/// # GPT model weights remote resources
/// `(cache_subdirectory, download_url)` pairs for the pre-trained OpenAI GPT
/// artifacts hosted on the Hugging Face model hub.
pub struct OpenAiGptModelResources;
/// # GPT model configuration remote resources
pub struct OpenAiGptConfigResources;
/// # GPT vocabulary remote resources
pub struct OpenAiGptVocabResources;
/// # GPT BPE merges remote resources
pub struct OpenAiGptMergesResources;

impl OpenAiGptModelResources {
    /// Pre-trained weights (`rust_model.ot`) for the `openai-gpt` checkpoint.
    pub const GPT: (&'static str, &'static str) = (
        "openai-gpt/model",
        "https://huggingface.co/openai-gpt/resolve/main/rust_model.ot",
    );
}

impl OpenAiGptConfigResources {
    /// Model configuration (`config.json`) for the `openai-gpt` checkpoint.
    pub const GPT: (&'static str, &'static str) = (
        "openai-gpt/config",
        "https://huggingface.co/openai-gpt/resolve/main/config.json",
    );
}

impl OpenAiGptVocabResources {
    /// Vocabulary file (`vocab.json`) for the `openai-gpt` checkpoint.
    pub const GPT: (&'static str, &'static str) = (
        "openai-gpt/vocab",
        "https://huggingface.co/openai-gpt/resolve/main/vocab.json",
    );
}

impl OpenAiGptMergesResources {
    /// BPE merges file (`merges.txt`) for the `openai-gpt` checkpoint.
    pub const GPT: (&'static str, &'static str) = (
        "openai-gpt/merges",
        "https://huggingface.co/openai-gpt/resolve/main/merges.txt",
    );
}
/// # GPT (OpenAI) base model
/// Decoder-only transformer: token and position embeddings, embedding dropout,
/// and a stack of attention blocks. `forward_t` returns the last hidden state
/// and optionally all intermediate hidden states / attention weights.
pub struct OpenAiGptModel {
    // Word embedding table; also reused for token type ids in `forward_t`
    tokens_embed: nn::Embedding,
    // Learned absolute position embeddings
    positions_embed: nn::Embedding,
    // Dropout applied to the summed embeddings before the first block
    drop: Dropout,
    // Transformer blocks, registered under the "h" variable-store prefix
    h: Vec<Block>,
    // When true, `forward_t` collects the hidden state after every block
    output_hidden_states: bool,
    // When true, `forward_t` collects the attention weights of every block
    output_attentions: bool,
}
impl OpenAiGptModel {
    /// Build a new `OpenAiGptModel`.
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the model weights
    /// * `config` - `Gpt2Config` defining the architecture (GPT reuses the GPT-2 config type)
    pub fn new<'p, P>(p: P, config: &Gpt2Config) -> OpenAiGptModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        // Token embedding table: vocab_size x n_embd
        let tokens_embed = embedding(
            p / "tokens_embed",
            config.vocab_size,
            config.n_embd,
            Default::default(),
        );
        // Position embedding table: n_positions x n_embd
        let positions_embed = embedding(
            p / "positions_embed",
            config.n_positions,
            config.n_embd,
            Default::default(),
        );
        // Embedding dropout probability defaults to 0.1 when absent from the config
        let embd_pdrop = config.embd_pdrop.unwrap_or(0.1);
        let drop = Dropout::new(embd_pdrop);
        // Stack of n_layer transformer blocks under the "h" prefix
        let mut h: Vec<Block> = vec![];
        let h_path = p / "h";
        for layer_index in 0..config.n_layer {
            h.push(Block::new(&h_path / layer_index, config, true));
        }
        // Both collection flags default to false when absent from the config
        let output_attentions = config.output_attentions.unwrap_or(false);
        let output_hidden_states = config.output_hidden_states.unwrap_or(false);
        OpenAiGptModel {
            tokens_embed,
            positions_embed,
            drop,
            h,
            output_hidden_states,
            output_attentions,
        }
    }

    /// Forward pass through the transformer.
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Token ids; mutually exclusive with `input_embeds`
    /// * `attention_mask` - Optional {0, 1} mask over input positions
    /// * `token_type_ids` - Optional segment ids, embedded with the word embedding table
    /// * `position_ids` - Optional explicit position ids; defaults to `0..seq_length`
    /// * `input_embeds` - Pre-computed embeddings; mutually exclusive with `input_ids`
    /// * `train` - Enables dropout when true
    ///
    /// # Errors
    ///
    /// Returns a `RustBertError` if neither or both of `input_ids` / `input_embeds` are provided.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<OpenAiGptModelOutput, RustBertError> {
        // Validates the ids/embeddings pair and computes token embeddings when ids are given
        let (calc_input_embeddings, input_shape, _) =
            process_ids_embeddings_pair(input_ids, input_embeds, &self.tokens_embed)?;
        let input_embeddings =
            input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
        let seq_length = input_shape[1];
        // Default position ids: 0..seq_length on the input's device, with a leading batch dim
        let position_ids = match position_ids {
            Some(value) => value.copy(),
            None => Tensor::arange(seq_length, (Int64, input_embeddings.device())).unsqueeze(0),
        };
        // Turn the {0, 1} mask into an additive mask: kept positions map to 0,
        // masked positions to -10000 so they vanish after the attention softmax
        let attention_mask = attention_mask.as_ref().map(|value| {
            ((value
                .view((input_embeddings.size()[0], -1))
                .unsqueeze(1)
                .unsqueeze(2)
                - 1.0)
                * 10000.0)
                .to_kind(input_embeddings.kind())
        });
        let position_embeds = position_ids.apply(&self.positions_embed);
        // Token type ids share the word embedding table; zeros when absent
        let token_type_embeds = match token_type_ids {
            Some(value) => value.apply(&self.tokens_embed),
            None => Tensor::zeros_like(&position_embeds),
        };
        // Sum the three embeddings, then apply embedding dropout (active only in training)
        let mut hidden_state: Tensor =
            (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
        // Optional accumulators, allocated only when collection was enabled in the config
        let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
            Some(vec![])
        } else {
            None
        };
        let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
            Some(vec![])
        } else {
            None
        };
        for layer in &self.h {
            let temp = layer.forward_t(&hidden_state, attention_mask.as_ref(), train);
            hidden_state = temp.0;
            if let Some(attentions) = all_attentions.borrow_mut() {
                attentions.push(temp.1.as_ref().unwrap().copy());
            };
            // NOTE: the hidden state is recorded *after* each block (the pre-block
            // embedding output is not included in `all_hidden_states`)
            if let Some(hidden_states) = all_hidden_states.borrow_mut() {
                hidden_states.push(hidden_state.as_ref().copy());
            };
        }
        Ok(OpenAiGptModelOutput {
            hidden_state,
            all_hidden_states,
            all_attentions,
        })
    }
}
/// # GPT language model head
/// Base GPT transformer followed by a bias-free linear projection from the
/// hidden size to the vocabulary size, producing next-token logits.
pub struct OpenAIGPTLMHeadModel {
    // Underlying transformer producing hidden states
    transformer: OpenAiGptModel,
    // Projection from n_embd to vocab_size (no bias)
    lm_head: LinearNoBias,
}
impl OpenAIGPTLMHeadModel {
    /// Build a new `OpenAIGPTLMHeadModel`.
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the root of the model weights
    /// * `config` - `Gpt2Config` defining the architecture
    pub fn new<'p, P>(p: P, config: &Gpt2Config) -> OpenAIGPTLMHeadModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let root = p.borrow();
        // The transformer registers its variables at the root; the projection
        // head lives under the "lm_head" prefix.
        OpenAIGPTLMHeadModel {
            transformer: OpenAiGptModel::new(root, config),
            lm_head: linear_no_bias(
                root / "lm_head",
                config.n_embd,
                config.vocab_size,
                Default::default(),
            ),
        }
    }
}
impl LMHeadModel for OpenAIGPTLMHeadModel {
    /// Forward pass producing language-model logits.
    ///
    /// GPT keeps no key/value cache: `_layer_past`, `_encoder_outputs` and
    /// `_decoder_input_ids` are ignored and the returned cache is `Cache::None`.
    /// The remaining arguments are forwarded unchanged to the base transformer.
    ///
    /// # Errors
    ///
    /// Propagates any `RustBertError` raised by the underlying transformer
    /// (e.g. invalid `input_ids` / `input_embeds` combination).
    fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        _layer_past: Cache,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        _encoder_outputs: Option<&Tensor>,
        _decoder_input_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
        let transformer_output = self.transformer.forward_t(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            input_embeds,
            train,
        )?;
        // Project the last hidden state onto the vocabulary
        Ok(LMModelOutput {
            lm_logits: transformer_output.hidden_state.apply(&self.lm_head),
            cache: Cache::None,
        })
    }
}
/// Container for the output of the GPT base model forward pass.
pub struct OpenAiGptModelOutput {
    /// Hidden state of the last transformer block
    pub hidden_state: Tensor,
    /// Hidden state collected after each block (only when enabled in the config)
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights of each block (only when enabled in the config)
    pub all_attentions: Option<Vec<Tensor>>,
}
/// # Language generation model based on the OpenAI GPT language model head
pub struct OpenAIGenerator {
    // Model used for decoding
    model: OpenAIGPTLMHeadModel,
    // Tokenizer matching the model vocabulary
    tokenizer: TokenizerOption,
    // Variable store holding the model weights
    var_store: nn::VarStore,
    // Generation configuration used at decoding time
    generate_config: GenerateConfig,
    // Set to None in `new`: GPT defines no BOS / EOS / PAD special tokens
    bos_token_id: Option<i64>,
    eos_token_ids: Option<Vec<i64>>,
    pad_token_id: Option<i64>,
    // Set to false in `new` (GPT is a decoder-only architecture)
    is_encoder_decoder: bool,
    // Copied from the model configuration
    vocab_size: i64,
    decoder_start_id: Option<i64>,
    // Maximum sequence length supported by the position embeddings
    max_position_embeddings: i64,
}
impl OpenAIGenerator {
pub fn new(generate_config: GenerateConfig) -> Result<OpenAIGenerator, RustBertError> {
generate_config.validate();
let model_resource = if generate_config.model_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
))
} else {
generate_config.vocab_resource.clone()
};
let merges_resource = if generate_config.merges_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
))
} else {
generate_config.merges_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let device = generate_config.device;
let mut var_store = nn::VarStore::new(device);
let tokenizer = TokenizerOption::from_file(
ModelType::OpenAiGpt,
vocab_path.to_str().unwrap(),
Some(merges_path.to_str().unwrap()),
true,
None,
None,
)?;
let config = Gpt2Config::from_file(config_path);
let model = OpenAIGPTLMHeadModel::new(&var_store.root(), &config);
var_store.load(weights_path)?;
let bos_token_id = None;
let eos_token_ids = None;
let pad_token_id = None;
let is_encoder_decoder = false;
let vocab_size = config.vocab_size;
let decoder_start_id = None;
let max_position_embeddings = config.n_positions;
Ok(OpenAIGenerator {
model,
tokenizer,
var_store,
generate_config,
bos_token_id,
eos_token_ids,
pad_token_id,
is_encoder_decoder,
vocab_size,
decoder_start_id,
max_position_embeddings,
})
}
}
// Wires `OpenAIGenerator` into the shared generation machinery: the default
// methods of `PrivateLanguageGenerator` drive decoding through these accessors.
impl PrivateLanguageGenerator<OpenAIGPTLMHeadModel, OpenAiGptVocab, OpenAiGptTokenizer>
    for OpenAIGenerator
{
    fn get_model(&self) -> &OpenAIGPTLMHeadModel {
        &self.model
    }
    fn _get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }
    fn get_var_store(&self) -> &nn::VarStore {
        &self.var_store
    }
    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
        &mut self.var_store
    }
    fn get_config(&self) -> &GenerateConfig {
        &self.generate_config
    }
    // The three special-token accessors return None for GPT (set in `new`)
    fn get_bos_id(&self) -> &Option<i64> {
        &self.bos_token_id
    }
    fn get_eos_ids(&self) -> &Option<Vec<i64>> {
        &self.eos_token_ids
    }
    fn get_pad_id(&self) -> &Option<i64> {
        &self.pad_token_id
    }
    fn is_encoder_decoder(&self) -> bool {
        self.is_encoder_decoder
    }
    fn get_vocab_size(&self) -> i64 {
        self.vocab_size
    }
    fn get_decoder_start_id(&self) -> Option<i64> {
        self.decoder_start_id
    }
    fn get_max_positions_embeddings(&self) -> i64 {
        self.max_position_embeddings
    }
}
// Marker implementation: `OpenAIGenerator` inherits the public generation API
// from the default methods of `LanguageGenerator`.
impl LanguageGenerator<OpenAIGPTLMHeadModel, OpenAiGptVocab, OpenAiGptTokenizer>
    for OpenAIGenerator
{
}