use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::gpt2::Gpt2Config;
use crate::openai_gpt::transformer::Block;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::{Config, RustBertError};
use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Int64;
use tch::nn::embedding;
use tch::{nn, Device, Tensor};
/// Pretrained model weight resources for OpenAI GPT.
pub struct OpenAiGptModelResources;
/// Pretrained model configuration resources for OpenAI GPT.
pub struct OpenAiGptConfigResources;
/// Pretrained vocabulary resources for OpenAI GPT.
pub struct OpenAiGptVocabResources;
/// Pretrained BPE merges resources for OpenAI GPT.
pub struct OpenAiGptMergesResources;
impl OpenAiGptModelResources {
    /// (local identifier, download URL) pair for the `openai-gpt` weights hosted on the Hugging Face hub.
    pub const GPT: (&'static str, &'static str) = (
        "openai-gpt/model",
        "https://huggingface.co/openai-gpt/resolve/main/rust_model.ot",
    );
}
impl OpenAiGptConfigResources {
    /// (local identifier, download URL) pair for the `openai-gpt` model configuration file.
    pub const GPT: (&'static str, &'static str) = (
        "openai-gpt/config",
        "https://huggingface.co/openai-gpt/resolve/main/config.json",
    );
}
impl OpenAiGptVocabResources {
    /// (local identifier, download URL) pair for the `openai-gpt` vocabulary file.
    pub const GPT: (&'static str, &'static str) = (
        "openai-gpt/vocab",
        "https://huggingface.co/openai-gpt/resolve/main/vocab.json",
    );
}
impl OpenAiGptMergesResources {
    /// (local identifier, download URL) pair for the `openai-gpt` BPE merges file.
    pub const GPT: (&'static str, &'static str) = (
        "openai-gpt/merges",
        "https://huggingface.co/openai-gpt/resolve/main/merges.txt",
    );
}
/// OpenAI GPT uses the same configuration layout as GPT-2, so the type is aliased.
pub type OpenAiGptConfig = Gpt2Config;
/// # OpenAI GPT base model
/// Token and position embeddings followed by a stack of transformer blocks.
/// Returns the final hidden state and, optionally, all hidden states and attention weights.
pub struct OpenAiGptModel {
    tokens_embed: nn::Embedding,    // token embedding table (also reused for token type ids in `forward_t`)
    positions_embed: nn::Embedding, // learned absolute position embeddings
    drop: Dropout,                  // dropout applied to the summed embeddings
    h: Vec<Block>,                  // transformer blocks, applied in order
    output_hidden_states: bool,     // collect per-layer hidden states when true
    output_attentions: bool,        // collect per-layer attention weights when true
}
impl OpenAiGptModel {
    /// Builds a new `OpenAiGptModel`.
    ///
    /// # Arguments
    /// * `p` - variable store path root under which all model parameters are registered
    /// * `config` - model configuration (shared layout with GPT-2)
    pub fn new<'p, P>(p: P, config: &Gpt2Config) -> OpenAiGptModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        // Token embedding table: vocab_size x n_embd.
        let tokens_embed = embedding(
            p / "tokens_embed",
            config.vocab_size,
            config.n_embd,
            Default::default(),
        );
        // Position embedding table: n_positions x n_embd.
        let positions_embed = embedding(
            p / "positions_embed",
            config.n_positions,
            config.n_embd,
            Default::default(),
        );
        // Embedding dropout probability defaults to 0.1 when absent from the config.
        let embd_pdrop = config.embd_pdrop.unwrap_or(0.1);
        let drop = Dropout::new(embd_pdrop);
        let mut h: Vec<Block> = vec![];
        let h_path = p / "h";
        for layer_index in 0..config.n_layer {
            // NOTE(review): the third argument presumably enables attention scaling
            // inside the block — confirm against `Block::new`.
            h.push(Block::new(&h_path / layer_index, config, true));
        }
        // Both output toggles default to false when not set in the config.
        let output_attentions = config.output_attentions.unwrap_or(false);
        let output_hidden_states = config.output_hidden_states.unwrap_or(false);
        OpenAiGptModel {
            tokens_embed,
            positions_embed,
            drop,
            h,
            output_hidden_states,
            output_attentions,
        }
    }
    /// Forward pass through the model.
    ///
    /// # Arguments
    /// * `input_ids` - token ids; mutually exclusive with `input_embeds`
    /// * `attention_mask` - optional {0, 1} mask over input positions
    /// * `token_type_ids` - optional token type ids, embedded with the token embedding table
    /// * `position_ids` - optional position ids; defaults to `0..seq_length`
    /// * `input_embeds` - precomputed input embeddings; mutually exclusive with `input_ids`
    /// * `train` - enables dropout when true
    ///
    /// # Returns
    /// `OpenAiGptModelOutput` with the last hidden state and the optional
    /// per-layer hidden states / attention weights, or a `RustBertError`.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<OpenAiGptModelOutput, RustBertError> {
        // Validates the ids/embeddings pair and embeds the ids when embeddings were not given.
        let (calc_input_embeddings, input_shape, _) =
            process_ids_embeddings_pair(input_ids, input_embeds, &self.tokens_embed)?;
        let input_embeddings =
            input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
        let seq_length = input_shape[1];
        // Default position ids: 0..seq_length with a leading broadcast (batch) dimension.
        let position_ids = match position_ids {
            Some(value) => value.copy(),
            None => Tensor::arange(seq_length, (Int64, input_embeddings.device())).unsqueeze(0),
        };
        // Convert the {0, 1} mask to additive form: (mask - 1) * 10000 gives 0 for kept
        // positions and -10000 for masked ones, shaped for broadcast over attention heads.
        let attention_mask = attention_mask.as_ref().map(|value| {
            ((value
                .view((input_embeddings.size()[0], -1))
                .unsqueeze(1)
                .unsqueeze(2)
                - 1.0)
                * 10000.0)
                .to_kind(input_embeddings.kind())
        });
        let position_embeds = position_ids.apply(&self.positions_embed);
        // Token type ids are embedded with the token embedding table; zeros when absent.
        let token_type_embeds = match token_type_ids {
            Some(value) => value.apply(&self.tokens_embed),
            None => Tensor::zeros_like(&position_embeds),
        };
        // Sum the three embeddings and apply embedding dropout.
        let mut hidden_state: Tensor =
            (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
        let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
            Some(vec![])
        } else {
            None
        };
        let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
            Some(vec![])
        } else {
            None
        };
        // Run the transformer blocks in order, optionally collecting per-layer outputs.
        for layer in &self.h {
            let temp = layer.forward_t(&hidden_state, attention_mask.as_ref(), train);
            hidden_state = temp.0;
            if let Some(attentions) = all_attentions.borrow_mut() {
                attentions.push(temp.1.unwrap());
            };
            if let Some(hidden_states) = all_hidden_states.borrow_mut() {
                hidden_states.push(hidden_state.as_ref().copy());
            };
        }
        Ok(OpenAiGptModelOutput {
            hidden_state,
            all_hidden_states,
            all_attentions,
        })
    }
}
/// # OpenAI GPT language model head
/// Base model followed by a bias-free linear projection to vocabulary logits.
pub struct OpenAIGPTLMHeadModel {
    transformer: OpenAiGptModel, // base transformer producing hidden states
    lm_head: LinearNoBias,       // n_embd -> vocab_size projection
}
impl OpenAIGPTLMHeadModel {
    /// Builds a new `OpenAIGPTLMHeadModel`.
    ///
    /// # Arguments
    /// * `p` - variable store path root for the model parameters
    /// * `config` - model configuration (shared layout with GPT-2)
    pub fn new<'p, P>(p: P, config: &Gpt2Config) -> OpenAIGPTLMHeadModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let root = p.borrow();
        // Field order matters for variable registration: transformer first, then lm_head.
        OpenAIGPTLMHeadModel {
            transformer: OpenAiGptModel::new(root, config),
            lm_head: linear_no_bias(
                root / "lm_head",
                config.n_embd,
                config.vocab_size,
                Default::default(),
            ),
        }
    }
    /// Forward pass: runs the base transformer and projects the final hidden
    /// state to vocabulary logits. The cache, encoder outputs and decoder input
    /// arguments are unused — this model is a decoder-only architecture without
    /// a key/value cache.
    ///
    /// # Returns
    /// `LMModelOutput` holding the logits and an empty cache, or a `RustBertError`.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        _layer_past: Cache,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        _encoder_outputs: Option<&Tensor>,
        _decoder_input_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
        let transformer_output = self.transformer.forward_t(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            input_embeds,
            train,
        )?;
        Ok(LMModelOutput {
            lm_logits: transformer_output.hidden_state.apply(&self.lm_head),
            cache: Cache::None,
        })
    }
}
/// Output container for the OpenAI GPT base model forward pass.
pub struct OpenAiGptModelOutput {
    /// Hidden state of the last transformer layer.
    pub hidden_state: Tensor,
    /// Hidden states for all layers, populated only when `output_hidden_states` is enabled.
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all layers, populated only when `output_attentions` is enabled.
    pub all_attentions: Option<Vec<Tensor>>,
}
/// # Language generator based on the OpenAI GPT architecture
/// Bundles the LM head model, tokenizer, weights store and the generation
/// settings consumed by the `LanguageGenerator` machinery.
pub struct OpenAIGenerator {
    model: OpenAIGPTLMHeadModel,     // model used for the forward passes
    tokenizer: TokenizerOption,      // tokenizer matching the model vocabulary
    var_store: nn::VarStore,         // variable store holding the loaded weights
    generate_config: GenerateConfig, // generation settings (beams, sampling, ...)
    bos_token_id: Option<i64>,       // beginning-of-sequence token id, if defined
    eos_token_ids: Option<Vec<i64>>, // end-of-sequence token ids, if defined
    pad_token_id: Option<i64>,       // padding token id, if defined
    is_encoder_decoder: bool,        // always false for this decoder-only model
    vocab_size: i64,                 // vocabulary size from the model config
    decoder_start_id: Option<i64>,   // decoder start token id from the config
    max_position_embeddings: i64,    // maximum supported sequence length
}
impl OpenAIGenerator {
    /// Creates a new `OpenAIGenerator` from a `GenerateConfig`, building the
    /// tokenizer from the vocabulary and merges resources in the config.
    ///
    /// # Errors
    /// Returns a `RustBertError` when the merges resource is missing, a resource
    /// cannot be resolved locally, or the tokenizer fails to load.
    pub fn new(generate_config: GenerateConfig) -> Result<OpenAIGenerator, RustBertError> {
        let vocab_path = generate_config.vocab_resource.get_local_path()?;
        // GPT tokenization requires a BPE merges file in addition to the vocabulary.
        let merges_resource = generate_config.merges_resource.as_ref().ok_or_else(|| {
            RustBertError::InvalidConfigurationError(
                "GPT expects a merges resources to be provided".to_string(),
            )
        })?;
        let merges_path = merges_resource.get_local_path()?;
        let tokenizer = TokenizerOption::from_file(
            ModelType::OpenAiGpt,
            vocab_path.to_str().unwrap(),
            Some(merges_path.to_str().unwrap()),
            true,
            None,
            None,
        )?;
        Self::new_with_tokenizer(generate_config, tokenizer)
    }
    /// Creates a new `OpenAIGenerator` from a `GenerateConfig` and an
    /// already-built tokenizer, loading the model weights from the configured
    /// resource onto the configured device.
    ///
    /// # Errors
    /// Returns a `RustBertError` when the config resource cannot be resolved
    /// or the weights fail to load.
    pub fn new_with_tokenizer(
        generate_config: GenerateConfig,
        tokenizer: TokenizerOption,
    ) -> Result<OpenAIGenerator, RustBertError> {
        generate_config.validate();
        let config_path = generate_config.config_resource.get_local_path()?;
        let device = generate_config.device;
        let mut var_store = nn::VarStore::new(device);
        let model_config = Gpt2Config::from_file(config_path);
        // The model must be built before loading so its variables exist in the store.
        let model = OpenAIGPTLMHeadModel::new(var_store.root(), &model_config);
        crate::resources::load_weights(
            &generate_config.model_resource,
            &mut var_store,
            generate_config.kind,
            device,
        )?;
        Ok(OpenAIGenerator {
            model,
            bos_token_id: tokenizer.get_bos_id(),
            eos_token_ids: tokenizer.get_eos_id().map(|id| vec![id]),
            pad_token_id: tokenizer.get_pad_id(),
            is_encoder_decoder: false,
            vocab_size: model_config.vocab_size,
            decoder_start_id: model_config.decoder_start_token_id,
            max_position_embeddings: model_config.n_positions,
            tokenizer,
            var_store,
            generate_config,
        })
    }
}
// Wires the generator into the shared generation machinery: simple accessors
// over the stored fields plus a forward pass delegating to the LM head model.
impl PrivateLanguageGenerator for OpenAIGenerator {
    fn _get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }
    fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        &mut self.tokenizer
    }
    fn get_device(&self) -> Device {
        self.var_store.device()
    }
    fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
        Ok(&mut self.var_store)
    }
    fn get_config(&self) -> &GenerateConfig {
        &self.generate_config
    }
    fn get_bos_id(&self) -> Option<i64> {
        self.bos_token_id
    }
    fn get_eos_ids(&self) -> Option<&Vec<i64>> {
        self.eos_token_ids.as_ref()
    }
    fn get_pad_id(&self) -> Option<i64> {
        self.pad_token_id
    }
    fn is_encoder_decoder(&self) -> bool {
        self.is_encoder_decoder
    }
    fn get_vocab_size(&self) -> i64 {
        self.vocab_size
    }
    fn get_decoder_start_id(&self) -> Option<i64> {
        self.decoder_start_id
    }
    fn get_max_positions_embeddings(&self) -> Option<i64> {
        Some(self.max_position_embeddings)
    }
    // Delegates to `OpenAIGPTLMHeadModel::forward_t`; cache and encoder-related
    // arguments are passed through but unused by this decoder-only model.
    fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        _layer_past: Cache,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        _encoder_outputs: Option<&Tensor>,
        _decoder_input_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
        self.model.forward_t(
            input_ids,
            _layer_past,
            attention_mask,
            token_type_ids,
            position_ids,
            input_embeds,
            _encoder_outputs,
            _decoder_input_ids,
            train,
        )
    }
}
// Marker impl: the public generation API is provided entirely by the
// `LanguageGenerator` trait's default methods.
impl LanguageGenerator for OpenAIGenerator {}