use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::gpt_neo::decoder::GptNeoBlock;
use crate::gpt_neo::LayerState;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::{Activation, Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::{Borrow, BorrowMut};
use tch::{nn, Device, Kind, Tensor};
/// Registry of pretrained GPT-Neo model weight resources (local alias / remote URL pairs).
pub struct GptNeoModelResources;
/// Registry of pretrained GPT-Neo configuration file resources.
pub struct GptNeoConfigResources;
/// Registry of pretrained GPT-Neo vocabulary file resources.
pub struct GptNeoVocabResources;
/// Registry of pretrained GPT-Neo BPE merges file resources.
pub struct GptNeoMergesResources;
impl GptNeoModelResources {
    /// (local alias, download URL) for the EleutherAI GPT-Neo 125M model weights.
    pub const GPT_NEO_125M: (&'static str, &'static str) = (
        "gpt-neo-125M/model",
        "https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/rust_model.ot",
    );
    /// (local alias, download URL) for the EleutherAI GPT-Neo 1.3B model weights.
    pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
        "gpt-neo-1_3B/model",
        "https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/rust_model.ot",
    );
    /// (local alias, download URL) for the EleutherAI GPT-Neo 2.7B model weights.
    pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
        "gpt-neo-2_7B/model",
        "https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/rust_model.ot",
    );
}
impl GptNeoConfigResources {
    /// (local alias, download URL) for the GPT-Neo 125M `config.json`.
    pub const GPT_NEO_125M: (&'static str, &'static str) = (
        "gpt-neo-125M/config",
        "https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/config.json",
    );
    /// (local alias, download URL) for the GPT-Neo 1.3B `config.json`.
    pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
        "gpt-neo-1_3B/config",
        "https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/config.json",
    );
    /// (local alias, download URL) for the GPT-Neo 2.7B `config.json`.
    pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
        "gpt-neo-2_7B/config",
        "https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/config.json",
    );
}
impl GptNeoVocabResources {
    /// (local alias, download URL) for the GPT-Neo 125M `vocab.json`.
    pub const GPT_NEO_125M: (&'static str, &'static str) = (
        "gpt-neo-125M/vocab",
        "https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/vocab.json",
    );
    /// (local alias, download URL) for the GPT-Neo 1.3B `vocab.json`.
    pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
        "gpt-neo-1_3B/vocab",
        "https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/vocab.json",
    );
    /// (local alias, download URL) for the GPT-Neo 2.7B `vocab.json`.
    pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
        "gpt-neo-2_7B/vocab",
        "https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/vocab.json",
    );
}
impl GptNeoMergesResources {
    /// (local alias, download URL) for the GPT-Neo 125M BPE `merges.txt`.
    pub const GPT_NEO_125M: (&'static str, &'static str) = (
        "gpt-neo-125M/merges",
        "https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/merges.txt",
    );
    /// (local alias, download URL) for the GPT-Neo 1.3B BPE `merges.txt`.
    pub const GPT_NEO_1_3B: (&'static str, &'static str) = (
        "gpt-neo-1_3B/merges",
        "https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/merges.txt",
    );
    /// (local alias, download URL) for the GPT-Neo 2.7B BPE `merges.txt`.
    pub const GPT_NEO_2_7B: (&'static str, &'static str) = (
        "gpt-neo-2_7B/merges",
        "https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/merges.txt",
    );
}
/// Type of attention used by a GPT-Neo transformer layer.
///
/// Serialized in camelCase (`"global"` / `"local"`) so deserialization matches
/// the attention layer names found in HuggingFace `config.json` files.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub enum AttentionLayerType {
    /// Full (unrestricted) self-attention over the whole sequence.
    Global,
    /// Windowed self-attention — presumably limited to `window_size` tokens; confirm in the decoder implementation.
    Local,
}
/// GPT-Neo model configuration.
///
/// Field names mirror the HuggingFace `config.json` schema so the struct can
/// be deserialized directly from a downloaded configuration file.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GptNeoConfig {
    /// Activation function used in the feed-forward blocks.
    pub activation_function: Activation,
    /// Dropout probability applied to attention weights.
    pub attention_dropout: f64,
    /// Per-layer attention type; one entry per transformer layer.
    pub attention_layers: Vec<AttentionLayerType>,
    /// Compressed form of the layer pattern: (pattern, repeat count) pairs.
    pub attention_types: Vec<(Vec<AttentionLayerType>, i64)>,
    /// Inner feed-forward dimension; `None` presumably falls back to a default derived from `hidden_size` — confirm in the decoder.
    pub intermediate_size: Option<i64>,
    /// Beginning-of-sequence token id.
    pub bos_token_id: i64,
    /// End-of-sequence token id.
    pub eos_token_id: i64,
    /// Optional token id forced as the first generated token.
    pub forced_bos_token_id: Option<i64>,
    /// Optional token id forced as the last generated token.
    pub forced_eos_token_id: Option<i64>,
    /// Vocabulary size (number of rows in the token embedding).
    pub vocab_size: i64,
    /// Number of transformer layers.
    pub num_layers: i64,
    /// Number of attention heads per layer.
    pub num_heads: i64,
    /// Hidden state dimension.
    pub hidden_size: i64,
    /// Attention window size — presumably used by `Local` attention layers.
    pub window_size: i64,
    /// Dropout probability applied to the summed input embeddings.
    pub embed_dropout: f64,
    /// Standard deviation used for weight initialization.
    pub initializer_range: f64,
    /// Epsilon used in the layer normalization layers.
    pub layer_norm_epsilon: f64,
    /// Maximum sequence length supported by the position embeddings.
    pub max_position_embeddings: i64,
    /// Whether to return past key/value caches (optional in the JSON schema).
    pub output_past: Option<bool>,
    /// Whether to return per-layer attention weights.
    pub output_attentions: Option<bool>,
    /// Whether to return per-layer hidden states.
    pub output_hidden_states: Option<bool>,
    /// Dropout probability applied to residual/output projections.
    pub resid_dropout: f64,
    /// Optional decoder start token id used during generation.
    pub decoder_start_token_id: Option<i64>,
}
impl Config for GptNeoConfig {}
impl Default for GptNeoConfig {
    /// Default configuration. The values (24 layers, hidden size 2048,
    /// 16 heads, vocab 50257) appear to mirror the EleutherAI gpt-neo-1.3B
    /// checkpoint — verify against the published `config.json` if it matters.
    fn default() -> Self {
        GptNeoConfig {
            activation_function: Activation::gelu_new,
            attention_dropout: 0.0,
            // 24 layers alternating Global / Local attention. `AttentionLayerType`
            // is `Copy`, so `.copied()` is the idiomatic (and allocation-free)
            // replacement for the previous `.map(|t| t.to_owned())`.
            attention_layers: [AttentionLayerType::Global, AttentionLayerType::Local]
                .iter()
                .cycle()
                .take(24)
                .copied()
                .collect::<Vec<AttentionLayerType>>(),
            // Compressed form of the same pattern: [Global, Local] repeated 12 times.
            attention_types: vec![(
                vec![AttentionLayerType::Global, AttentionLayerType::Local],
                12,
            )],
            intermediate_size: None,
            bos_token_id: 50256,
            eos_token_id: 50256,
            forced_bos_token_id: None,
            forced_eos_token_id: None,
            vocab_size: 50257,
            num_layers: 24,
            num_heads: 16,
            hidden_size: 2048,
            window_size: 256,
            embed_dropout: 0.0,
            initializer_range: 0.02,
            layer_norm_epsilon: 1e-5,
            max_position_embeddings: 2048,
            output_past: None,
            output_attentions: None,
            output_hidden_states: None,
            resid_dropout: 0.0,
            decoder_start_token_id: None,
        }
    }
}
/// GPT-Neo base transformer model (no language-modeling head).
pub struct GptNeoModel {
    /// Token embedding table (`wte`); also reused as the output projection by the LM head.
    word_embeddings: nn::Embedding,
    /// Learned absolute position embeddings (`wpe`).
    position_embeddings: nn::Embedding,
    /// Stack of `num_layers` decoder blocks (`h`).
    layers: Vec<GptNeoBlock>,
    /// Dropout applied to the summed embeddings before the first block.
    dropout: Dropout,
    /// Final layer normalization (`ln_f`).
    layer_norm: nn::LayerNorm,
    /// Whether forward passes collect per-layer attention weights.
    output_attentions: bool,
    /// Whether forward passes collect per-layer hidden states.
    output_hidden_states: bool,
}
impl GptNeoModel {
    /// Builds a GPT-Neo base model under the given variable-store path.
    ///
    /// Creates the token (`wte`) and position (`wpe`) embeddings, the final
    /// layer norm (`ln_f`) and `config.num_layers` decoder blocks under `h`.
    ///
    /// # Errors
    /// Returns `Result` for API consistency; the visible construction path
    /// does not itself produce errors.
    pub fn new<'p, P>(p: P, config: &GptNeoConfig) -> Result<GptNeoModel, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        let word_embeddings = nn::embedding(
            p / "wte",
            config.vocab_size,
            config.hidden_size,
            Default::default(),
        );
        let position_embeddings = nn::embedding(
            p / "wpe",
            config.max_position_embeddings,
            config.hidden_size,
            Default::default(),
        );
        let dropout = Dropout::new(config.embed_dropout);
        let layer_norm_config = nn::LayerNormConfig {
            eps: config.layer_norm_epsilon,
            ..Default::default()
        };
        let layer_norm = nn::layer_norm(p / "ln_f", vec![config.hidden_size], layer_norm_config);
        let mut layers: Vec<GptNeoBlock> = Vec::with_capacity(config.num_layers as usize);
        let p_layers = p / "h";
        // Each block receives its index so it can select its attention type
        // from `config.attention_layers`.
        for layer_index in 0..config.num_layers {
            layers.push(GptNeoBlock::new(
                &p_layers / layer_index,
                layer_index as usize,
                config,
            ));
        }
        let output_attentions = config.output_attentions.unwrap_or(false);
        let output_hidden_states = config.output_hidden_states.unwrap_or(false);
        Ok(GptNeoModel {
            word_embeddings,
            position_embeddings,
            layers,
            dropout,
            layer_norm,
            output_attentions,
            output_hidden_states,
        })
    }
    /// Forward pass through the base transformer.
    ///
    /// Exactly one of `input_ids` / `input_embeds` must be provided (enforced
    /// by `process_ids_embeddings_pair`). `layer_states` carries the per-layer
    /// key/value cache from a previous step; `attention_mask` uses the
    /// HuggingFace convention of 1 = attend, 0 = masked.
    ///
    /// # Errors
    /// Propagates errors from input validation and from the decoder blocks.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        layer_states: Option<Vec<Option<LayerState>>>,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> Result<GptNeoModelOutput, RustBertError> {
        let (calc_input_embeddings, input_shape, device) =
            process_ids_embeddings_pair(input_ids, input_embeds, &self.word_embeddings)?;
        let (batch_size, current_sequence_length) = (input_shape[0], input_shape[1]);
        // Recover the cached sequence length from the first layer's cache:
        // walk prev_key's shape from the back, skip the last dimension and
        // take the second-to-last — presumably [.., seq_len, head_dim].
        let past_length = if let Some(past_state_value) = &layer_states {
            if let Some(first_layer_state) = &past_state_value[0] {
                let mut size_iter = first_layer_state.prev_key.size().into_iter().rev();
                size_iter.next();
                size_iter.next().unwrap()
            } else {
                0
            }
        } else {
            0
        };
        let full_sequence_length = current_sequence_length + past_length;
        // When the caller gives no position ids, number the new tokens
        // [past_length, full_sequence_length) so positions continue the cache.
        let calc_position_ids = if position_ids.is_none() {
            let position_ids =
                Tensor::arange_start(past_length, full_sequence_length, (Kind::Int64, device));
            Some(
                position_ids
                    .unsqueeze(0)
                    .view([-1, current_sequence_length]),
            )
        } else {
            None
        };
        let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());
        let input_embeds = input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
        let position_embeds = position_ids.apply(&self.position_embeddings);
        // Convert the 1/0 attention mask to an additive mask: masked positions
        // get a large negative value (-1e4) so they vanish after softmax.
        let attention_mask = attention_mask.map(|attention_mask_value| {
            let attention_mask = attention_mask_value
                .view([batch_size, -1])
                .unsqueeze(1)
                .unsqueeze(1);
            let attention_mask = attention_mask.to_kind(position_embeds.kind());
            (1 - attention_mask) * -1e4
        });
        let mut hidden_state = input_embeds + position_embeds;
        // Token type ids are looked up in the *word* embedding table,
        // mirroring the GPT-2 family behavior.
        if let Some(token_type_ids) = token_type_ids {
            hidden_state = hidden_state + token_type_ids.apply(&self.word_embeddings);
        };
        hidden_state = hidden_state.apply_t(&self.dropout, train);
        let mut output_shape = input_shape;
        output_shape.push(*hidden_state.size().last().unwrap());
        let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
            Some(vec![])
        } else {
            None
        };
        let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
            Some(vec![])
        } else {
            None
        };
        // Without a provided cache, start every layer with an empty state.
        let old_cache = layer_states.unwrap_or_else(|| vec![None; self.layers.len()]);
        let mut next_cache = vec![None; self.layers.len()];
        // `x` threads the hidden state through the layer loop; it is `None`
        // only before the first layer runs.
        let mut x: Option<Tensor> = None;
        let mut attention_weights: Option<Tensor>;
        for ((layer_idx, layer), layer_state) in
            self.layers.iter().enumerate().zip(old_cache.into_iter())
        {
            let temp = if let Some(x_value) = &x {
                layer.forward_t(
                    x_value,
                    layer_state.as_ref(),
                    attention_mask.as_ref(),
                    train,
                )?
            } else {
                layer.forward_t(
                    &hidden_state,
                    layer_state.as_ref(),
                    attention_mask.as_ref(),
                    train,
                )?
            };
            x = Some(temp.0);
            attention_weights = temp.1;
            next_cache[layer_idx] = temp.2;
            // NOTE(review): this unwrap assumes blocks always return attention
            // weights when `output_attentions` is set — confirm in GptNeoBlock.
            if let Some(attentions) = all_attentions.borrow_mut() {
                attentions.push(std::mem::take(&mut attention_weights.unwrap()));
            };
            if let Some(hidden_states) = all_hidden_states.borrow_mut() {
                hidden_states.push(x.as_ref().unwrap().copy());
            };
        }
        // Final layer norm, then restore the [batch, seq, hidden] shape.
        let hidden_states = x
            .unwrap()
            .apply(&self.layer_norm)
            .view(output_shape.as_slice());
        Ok(GptNeoModelOutput {
            hidden_states,
            next_cache: Some(next_cache),
            all_hidden_states,
            all_attentions,
        })
    }
}
pub struct GptNeoForCausalLM {
transformer: GptNeoModel,
}
impl GptNeoForCausalLM {
pub fn new<'p, P>(p: P, config: &GptNeoConfig) -> Result<GptNeoForCausalLM, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let transformer = GptNeoModel::new(p / "transformer", config)?;
Ok(GptNeoForCausalLM { transformer })
}
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
layer_states: Option<Vec<Option<LayerState>>>,
attention_mask: Option<&Tensor>,
train: bool,
) -> Result<GptNeoModelLMOutput, RustBertError> {
let base_model_output = self.transformer.forward_t(
input_ids,
input_embeds,
token_type_ids,
position_ids,
layer_states,
attention_mask,
train,
)?;
let lm_logits = base_model_output
.hidden_states
.linear::<Tensor>(&self.transformer.word_embeddings.ws, None);
Ok(GptNeoModelLMOutput {
lm_logits,
next_cache: base_model_output.next_cache,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
})
}
}
/// Output of the GPT-Neo base model forward pass.
pub struct GptNeoModelOutput {
    /// Final hidden states after the last layer norm, shaped [batch, seq, hidden].
    pub hidden_states: Tensor,
    /// Updated per-layer key/value cache for incremental decoding.
    pub next_cache: Option<Vec<Option<LayerState>>>,
    /// Per-layer hidden states, collected only when `output_hidden_states` is set.
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Per-layer attention weights, collected only when `output_attentions` is set.
    pub all_attentions: Option<Vec<Tensor>>,
}
/// Output of the GPT-Neo causal language model forward pass.
pub struct GptNeoModelLMOutput {
    /// Vocabulary logits for each position.
    pub lm_logits: Tensor,
    /// Updated per-layer key/value cache for incremental decoding.
    pub next_cache: Option<Vec<Option<LayerState>>>,
    /// Per-layer hidden states, collected only when `output_hidden_states` is set.
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Per-layer attention weights, collected only when `output_attentions` is set.
    pub all_attentions: Option<Vec<Tensor>>,
}
/// Text generator wrapping a GPT-Neo causal language model, its tokenizer
/// and the settings needed by the generation pipeline.
pub struct GptNeoGenerator {
    /// Causal LM used to produce logits.
    model: GptNeoForCausalLM,
    /// Tokenizer used for encoding prompts and decoding output.
    tokenizer: TokenizerOption,
    /// Variable store holding the model weights (also defines the device).
    var_store: nn::VarStore,
    /// Generation settings supplied at construction.
    generate_config: GenerateConfig,
    /// Beginning-of-sequence token id, taken from the tokenizer.
    bos_token_id: Option<i64>,
    /// End-of-sequence token ids, taken from the tokenizer.
    eos_token_ids: Option<Vec<i64>>,
    /// Padding token id, taken from the tokenizer.
    pad_token_id: Option<i64>,
    /// Always false: GPT-Neo is a decoder-only model.
    is_encoder_decoder: bool,
    /// Vocabulary size from the model configuration.
    vocab_size: i64,
    /// Optional decoder start token id from the model configuration.
    decoder_start_id: Option<i64>,
    /// Maximum sequence length supported by the position embeddings.
    max_position_embeddings: i64,
}
impl GptNeoGenerator {
    /// Creates a GPT-Neo generator, building the tokenizer from the vocab and
    /// merges resources in `generate_config`.
    ///
    /// # Errors
    /// Fails if the merges resource is absent, if any resource cannot be
    /// resolved to a local path, or if tokenizer/model loading fails.
    pub fn new(generate_config: GenerateConfig) -> Result<GptNeoGenerator, RustBertError> {
        let vocab_path = generate_config.vocab_resource.get_local_path()?;
        // The merges resource is optional on `GenerateConfig` but mandatory
        // for GPT-Neo's BPE tokenizer.
        let merges_path = generate_config
            .merges_resource
            .as_ref()
            .ok_or_else(|| {
                RustBertError::InvalidConfigurationError(
                    "GPT-Neo expects a merges resources to be provided".to_string(),
                )
            })?
            .get_local_path()?;
        let tokenizer = TokenizerOption::from_file(
            ModelType::GPTNeo,
            vocab_path.to_str().unwrap(),
            Some(merges_path.to_str().unwrap()),
            false,
            None,
            None,
        )?;
        Self::new_with_tokenizer(generate_config, tokenizer)
    }
    /// Creates a GPT-Neo generator with a caller-supplied tokenizer.
    ///
    /// Loads the configuration and weights onto the configured device, and
    /// captures the special-token ids and size limits used during generation.
    ///
    /// # Errors
    /// Fails if the configuration resource cannot be resolved, model
    /// construction fails, or the weights cannot be loaded.
    pub fn new_with_tokenizer(
        generate_config: GenerateConfig,
        tokenizer: TokenizerOption,
    ) -> Result<GptNeoGenerator, RustBertError> {
        let config_path = generate_config.config_resource.get_local_path()?;
        let device = generate_config.device;
        // Validate generation settings before doing any expensive loading.
        generate_config.validate();
        let mut var_store = nn::VarStore::new(device);
        let config = GptNeoConfig::from_file(config_path);
        let model = GptNeoForCausalLM::new(var_store.root(), &config)?;
        crate::resources::load_weights(
            &generate_config.model_resource,
            &mut var_store,
            generate_config.kind,
            device,
        )?;
        // Special-token ids come from the tokenizer; model-level limits from the config.
        let bos_token_id = tokenizer.get_bos_id();
        let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
        let pad_token_id = tokenizer.get_pad_id();
        let is_encoder_decoder = false;
        let vocab_size = config.vocab_size;
        let decoder_start_id = config.decoder_start_token_id;
        let max_position_embeddings = config.max_position_embeddings;
        Ok(GptNeoGenerator {
            model,
            tokenizer,
            var_store,
            generate_config,
            bos_token_id,
            eos_token_ids,
            pad_token_id,
            is_encoder_decoder,
            vocab_size,
            decoder_start_id,
            max_position_embeddings,
        })
    }
}
/// Hooks the GPT-Neo generator into the shared generation machinery.
/// Most methods are simple accessors over the fields captured at construction.
impl PrivateLanguageGenerator for GptNeoGenerator {
    fn _get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }
    fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        &mut self.tokenizer
    }
    fn get_device(&self) -> Device {
        self.var_store.device()
    }
    fn get_var_store_mut(&mut self) -> Result<&mut nn::VarStore, RustBertError> {
        Ok(&mut self.var_store)
    }
    fn get_config(&self) -> &GenerateConfig {
        &self.generate_config
    }
    fn get_bos_id(&self) -> Option<i64> {
        self.bos_token_id
    }
    fn get_eos_ids(&self) -> Option<&Vec<i64>> {
        self.eos_token_ids.as_ref()
    }
    fn get_pad_id(&self) -> Option<i64> {
        self.pad_token_id
    }
    fn is_encoder_decoder(&self) -> bool {
        self.is_encoder_decoder
    }
    fn get_vocab_size(&self) -> i64 {
        self.vocab_size
    }
    fn get_decoder_start_id(&self) -> Option<i64> {
        self.decoder_start_id
    }
    fn get_max_positions_embeddings(&self) -> Option<i64> {
        Some(self.max_position_embeddings)
    }
    /// Runs the LM forward pass, dispatching on the cache variant.
    ///
    /// Accepts `Cache::GPTNeoCache` (with or without past states) or
    /// `Cache::None`; any other variant is a usage error.
    fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        layer_past: Cache,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        _encoder_outputs: Option<&Tensor>,
        _decoder_input_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
        let base_model_output = match layer_past {
            Cache::GPTNeoCache(layer_past) => self.model.forward_t(
                input_ids,
                input_embeds,
                token_type_ids,
                position_ids,
                layer_past,
                attention_mask,
                train,
            ),
            Cache::None => self.model.forward_t(
                input_ids,
                input_embeds,
                token_type_ids,
                position_ids,
                None,
                attention_mask,
                train,
            ),
            _ => {
                return Err(RustBertError::ValueError(
                    "Cache not compatible with GPT-Neo Model".into(),
                ));
            }
        }?;
        Ok(LMModelOutput {
            lm_logits: base_model_output.lm_logits,
            cache: Cache::GPTNeoCache(base_model_output.next_cache),
        })
    }
    /// Prepares decoding inputs for the next generation step.
    ///
    /// Position ids are derived from the attention mask (cumulative sum of
    /// attended positions minus one, with padded positions filled with 1).
    /// When a cache is present, only the last token (and its position id)
    /// is fed to the model.
    fn prepare_inputs_for_generation<'a>(
        &self,
        input_ids: Tensor,
        _encoder_outputs: Option<&'a Tensor>,
        past: Cache,
        attention_mask: Tensor,
    ) -> PreparedInput<'a> {
        let position_ids = (attention_mask.totype(Kind::Int64).cumsum(-1, Kind::Int64) - 1)
            .masked_fill(&attention_mask.eq(0), 1);
        match past {
            Cache::GPTNeoCache(past) => {
                if past.is_some() {
                    PreparedInput {
                        prepared_input: Some(input_ids.select(1, -1).unsqueeze(-1)),
                        prepared_attention_mask: Some(attention_mask),
                        prepared_encoder_output: None,
                        prepared_decoder_input: None,
                        prepared_position_ids: Some(position_ids.select(1, -1).unsqueeze(-1)),
                        prepared_past: Cache::GPTNeoCache(past),
                    }
                } else {
                    PreparedInput {
                        prepared_input: Some(input_ids),
                        prepared_attention_mask: Some(attention_mask),
                        prepared_encoder_output: None,
                        prepared_decoder_input: None,
                        prepared_position_ids: Some(position_ids),
                        prepared_past: Cache::GPTNeoCache(None),
                    }
                }
            }
            Cache::None => PreparedInput {
                prepared_input: Some(input_ids),
                prepared_attention_mask: Some(attention_mask),
                prepared_encoder_output: None,
                prepared_decoder_input: None,
                prepared_position_ids: Some(position_ids),
                prepared_past: Cache::GPTNeoCache(None),
            },
            _ => panic!("Cache type incompatible with GPT-Neo"),
        }
    }
    /// Reorders each layer's key/value cache along the beam dimension after
    /// beam search reshuffling. Returns `None`: GPT-Neo is decoder-only, so
    /// there are no encoder outputs to pass back.
    fn reorder_cache(
        &self,
        past: &mut Cache,
        _encoder_outputs: Option<Tensor>,
        beam_indices: &Tensor,
    ) -> Option<Tensor> {
        match past {
            Cache::GPTNeoCache(cached_decoder_state) => match cached_decoder_state {
                Some(old_cache) => {
                    for layer_state in old_cache.iter_mut() {
                        if layer_state.is_some() {
                            layer_state.as_mut().unwrap().reorder_cache(beam_indices)
                        };
                    }
                    None
                }
                None => None,
            },
            Cache::None => None,
            _ => {
                panic!("Invalid cache for GPT-Neo model");
            }
        }
    }
}
impl LanguageGenerator for GptNeoGenerator {}