use crate::bart::BartModelOutput;
use crate::common::kind::get_negative_infinity;
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{Gpt2ConfigResources, Gpt2ModelResources, Gpt2VocabResources};
use crate::mbart::MBartConfig;
use crate::pegasus::decoder::PegasusDecoder;
use crate::pegasus::encoder::PegasusEncoder;
use crate::pegasus::LayerState;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
};
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::{PegasusTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::PegasusVocab;
use std::borrow::Borrow;
use tch::nn::{embedding, EmbeddingConfig, Init};
use tch::{nn, Tensor};
/// # Pegasus Pretrained model weight files
pub struct PegasusModelResources;
/// # Pegasus Pretrained model config files
pub struct PegasusConfigResources;
/// # Pegasus Pretrained model vocab (SentencePiece) files
pub struct PegasusVocabResources;
impl PegasusModelResources {
    /// Shared under Apache 2.0 license at <https://huggingface.co/google/pegasus-cnn_dailymail>.
    pub const CNN_DAILYMAIL: (&'static str, &'static str) = (
        "pegasus-cnn_dailymail/model",
        "https://huggingface.co/google/pegasus-cnn_dailymail/resolve/main/rust_model.ot",
    );
}
impl PegasusConfigResources {
    /// Shared under Apache 2.0 license at <https://huggingface.co/google/pegasus-cnn_dailymail>.
    pub const CNN_DAILYMAIL: (&'static str, &'static str) = (
        "pegasus-cnn_dailymail/config",
        "https://huggingface.co/google/pegasus-cnn_dailymail/resolve/main/config.json",
    );
}
impl PegasusVocabResources {
    /// Shared under Apache 2.0 license at <https://huggingface.co/google/pegasus-cnn_dailymail>.
    pub const CNN_DAILYMAIL: (&'static str, &'static str) = (
        "pegasus-cnn_dailymail/spiece",
        "https://huggingface.co/google/pegasus-cnn_dailymail/resolve/main/spiece.model",
    );
}
/// Pegasus re-uses the MBart configuration schema (same JSON fields).
pub type PegasusConfig = MBartConfig;
/// Shifts the input ids one position to the right, placing the decoder start
/// token at position 0 and mapping any label-ignore values (-100) back to the
/// padding token. Used to derive decoder inputs from labels (teacher forcing).
fn _shift_tokens_right(
    input_ids: &Tensor,
    pad_token_id: i64,
    decoder_start_token_id: i64,
) -> Tensor {
    let sequence_length = input_ids.size()[1];
    let mut shifted = Tensor::zeros(
        input_ids.size().as_slice(),
        (input_ids.kind(), input_ids.device()),
    );
    // Copy tokens [0, len-1) into positions [1, len).
    let _ = shifted
        .slice(1, 1, sequence_length, 1)
        .copy_(&input_ids.slice(1, 0, sequence_length - 1, 1));
    // The first decoder position always holds the decoder start token.
    let _ = shifted.select(1, 0).fill_(decoder_start_token_id);
    // -100 marks positions ignored by the loss; replace with padding.
    let _ = shifted.masked_fill_(&shifted.eq(-100), pad_token_id);
    shifted
}
/// # Pegasus Base model
/// Transformer encoder-decoder sharing a single token embedding matrix.
pub struct PegasusModel {
    /// Encoder stack (exposed crate-wide so the generation head can encode separately)
    pub(crate) encoder: PegasusEncoder,
    decoder: PegasusDecoder,
    /// Token embeddings shared by the encoder, the decoder and (via weight tying) the LM head
    pub(crate) embeddings: nn::Embedding,
}
impl PegasusModel {
    /// Builds a new `PegasusModel`.
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path where the model parameters are stored/loaded
    /// * `config` - `PegasusConfig` describing the architecture
    pub fn new<'p, P>(p: P, config: &PegasusConfig) -> PegasusModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        let embedding_config = EmbeddingConfig {
            // Padding positions keep a zero embedding and receive no gradient.
            padding_idx: config.pad_token_id.unwrap_or(0),
            ..Default::default()
        };
        let embeddings: nn::Embedding = embedding(
            path / "shared",
            config.vocab_size,
            config.d_model,
            embedding_config,
        );
        PegasusModel {
            encoder: PegasusEncoder::new(path / "encoder", config),
            decoder: PegasusDecoder::new(path / "decoder", config),
            embeddings,
        }
    }
    /// Forward pass through the encoder-decoder.
    ///
    /// Either `input_ids` or a precomputed `encoder_output` must be provided;
    /// when `encoder_output` is given the encoder is skipped entirely.
    /// `layer_states` carries cached decoder key/values for incremental decoding.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        decoder_input_ids: &Tensor,
        encoder_output: Option<&Tensor>,
        decoder_attention_mask: Option<&Tensor>,
        layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
        train: bool,
    ) -> PegasusModelOutput {
        // Run the encoder only when no precomputed encoder output was supplied.
        let calc_encoder_output = match encoder_output {
            Some(_) => None,
            None => Some(self.encoder.forward_t(
                input_ids.unwrap(),
                attention_mask,
                &self.embeddings,
                train,
            )),
        };
        let (calc_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
            match calc_encoder_output {
                Some(output) => (
                    Some(output.hidden_state),
                    output.all_hidden_states,
                    output.all_attentions,
                ),
                None => (None, None, None),
            };
        // Use the caller-provided encoder states when available, otherwise the
        // freshly computed ones (exactly one of the two exists at this point).
        let encoder_hidden_states = match encoder_output {
            Some(hidden_states) => hidden_states,
            None => calc_hidden_states.as_ref().unwrap(),
        };
        let decoder_output = self.decoder.forward_t(
            decoder_input_ids,
            encoder_hidden_states,
            attention_mask,
            decoder_attention_mask,
            &self.embeddings,
            layer_states,
            train,
        );
        PegasusModelOutput {
            decoder_output: decoder_output.hidden_state,
            // Only populated when the encoder actually ran in this call.
            encoder_hidden_state: calc_hidden_states,
            cache: decoder_output.next_decoder_cache,
            all_decoder_hidden_states: decoder_output.all_hidden_states,
            all_decoder_attentions: decoder_output.all_attentions,
            all_encoder_hidden_states,
            all_encoder_attentions,
        }
    }
}
/// # Pegasus Model for conditional generation
/// Base Pegasus model with a language-modeling head (tied to the shared embeddings)
/// plus a per-token logits bias loaded from the pretrained weights.
pub struct PegasusForConditionalGeneration {
    base_model: PegasusModel,
    /// Bias of shape [1, vocab_size] added to the projected logits
    final_logits_bias: Tensor,
    pad_token_id: i64,
    decoder_start_token_id: i64,
}
impl PegasusForConditionalGeneration {
    /// Builds a new `PegasusForConditionalGeneration`.
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path where the model parameters are stored/loaded
    /// * `config` - `PegasusConfig` describing the architecture
    pub fn new<'p, P>(p: P, config: &PegasusConfig) -> PegasusForConditionalGeneration
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        PegasusForConditionalGeneration {
            base_model: PegasusModel::new(path / "model", config),
            // Initialized to zero; actual values come from the loaded weights.
            final_logits_bias: path.var(
                "final_logits_bias",
                &[1, config.vocab_size],
                Init::Const(0.0),
            ),
            pad_token_id: config.pad_token_id.unwrap_or(0),
            decoder_start_token_id: config.decoder_start_token_id.unwrap_or(0),
        }
    }
    /// Forward pass returning vocabulary logits in `decoder_output`.
    ///
    /// When `decoder_input_ids` is not provided, decoder inputs are derived by
    /// shifting `input_ids` one position to the right (teacher forcing).
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        encoder_output: Option<&Tensor>,
        decoder_input_ids: Option<&Tensor>,
        decoder_attention_mask: Option<&Tensor>,
        old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
        train: bool,
    ) -> PegasusModelOutput {
        let generated_decoder_input_ids = match decoder_input_ids {
            Some(_) => None,
            None => Some(_shift_tokens_right(
                input_ids.unwrap(),
                self.pad_token_id,
                self.decoder_start_token_id,
            )),
        };
        let decoder_input_ids =
            decoder_input_ids.unwrap_or_else(|| generated_decoder_input_ids.as_ref().unwrap());
        let base_model_output = self.base_model.forward_t(
            input_ids,
            attention_mask,
            decoder_input_ids,
            encoder_output,
            decoder_attention_mask,
            old_layer_states,
            train,
        );
        // Project onto the vocabulary with the tied embedding weights and add
        // the model-specific logits bias.
        let lm_logits = base_model_output
            .decoder_output
            .linear::<Tensor>(&self.base_model.embeddings.ws, None)
            + &self.final_logits_bias;
        PegasusModelOutput {
            decoder_output: lm_logits,
            ..base_model_output
        }
    }
    /// Runs only the encoder, returning its final hidden state (inference mode).
    pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
        let encoder_output = self.base_model.encoder.forward_t(
            input_ids,
            attention_mask,
            &self.base_model.embeddings,
            false,
        );
        encoder_output.hidden_state
    }
}
impl LMHeadModel for PegasusForConditionalGeneration {
    /// Forward pass for language-model generation.
    ///
    /// Accepts either a `Cache::BARTCache` (cached decoder states) or
    /// `Cache::None`; any other cache variant is an error. `decoder_input_ids`
    /// is mandatory for this encoder-decoder model.
    ///
    /// # Errors
    /// * `RustBertError::ValueError` when the cache variant is incompatible or
    ///   `decoder_input_ids` is `None`.
    fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        cache: Cache,
        attention_mask: Option<&Tensor>,
        _token_type_ids: Option<&Tensor>,
        _position_ids: Option<&Tensor>,
        _input_embeds: Option<&Tensor>,
        encoder_outputs: Option<&Tensor>,
        decoder_input_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<LMModelOutput, RustBertError> {
        // `Cache::None` behaves exactly like `Cache::BARTCache(None)`, so the
        // two previously duplicated forward calls are collapsed into one.
        let cached_layer_states = match cache {
            Cache::BARTCache(cached_layer_states) => cached_layer_states,
            Cache::None => None,
            _ => {
                return Err(RustBertError::ValueError(
                    "Cache not compatible with Pegasus Model".into(),
                ));
            }
        };
        let decoder_input_ids = decoder_input_ids.ok_or_else(|| {
            RustBertError::ValueError(
                "Decoder input ids must be provided for Pegasus language models".to_string(),
            )
        })?;
        let base_model_output = self.base_model.forward_t(
            input_ids,
            attention_mask,
            decoder_input_ids,
            encoder_outputs,
            None,
            cached_layer_states,
            train,
        );
        // Tied embedding weights act as the LM head; add the loaded logits bias.
        let lm_logits = base_model_output
            .decoder_output
            .linear::<Tensor>(&self.base_model.embeddings.ws, None)
            + &self.final_logits_bias;
        Ok(LMModelOutput {
            lm_logits,
            cache: Cache::BARTCache(base_model_output.cache),
        })
    }
}
/// # Language generation model based on the Pegasus architecture
/// Bundles the conditional-generation model with its tokenizer, variable store
/// and generation settings for use with the `LanguageGenerator` pipeline.
pub struct PegasusConditionalGenerator {
    model: PegasusForConditionalGeneration,
    tokenizer: TokenizerOption,
    var_store: nn::VarStore,
    generate_config: GenerateConfig,
    // Always `Some(0)` for Pegasus (set in the constructor)
    bos_token_id: Option<i64>,
    // Always `Some` for Pegasus (set in the constructor)
    eos_token_ids: Option<Vec<i64>>,
    pad_token_id: Option<i64>,
    is_encoder_decoder: bool,
    vocab_size: i64,
    decoder_start_id: Option<i64>,
    max_position_embeddings: i64,
}
impl PegasusConditionalGenerator {
pub fn new(
generate_config: GenerateConfig,
) -> Result<PegasusConditionalGenerator, RustBertError> {
let model_resource = if generate_config.model_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
PegasusModelResources::CNN_DAILYMAIL,
))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
PegasusConfigResources::CNN_DAILYMAIL,
))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
PegasusVocabResources::CNN_DAILYMAIL,
))
} else {
generate_config.vocab_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let tokenizer = TokenizerOption::from_file(
ModelType::Pegasus,
vocab_path.to_str().unwrap(),
None,
false,
None,
None,
)?;
let config = PegasusConfig::from_file(config_path);
let model = PegasusForConditionalGeneration::new(&var_store.root(), &config);
var_store.load(weights_path)?;
let bos_token_id = Some(0);
let eos_token_ids = Some(match config.eos_token_id {
Some(value) => vec![value],
None => vec![1],
});
let pad_token_id = Some(config.pad_token_id.unwrap_or(0));
let vocab_size = config.vocab_size;
let is_encoder_decoder = true;
let decoder_start_id = Some(0);
let max_position_embeddings = config.max_position_embeddings;
Ok(PegasusConditionalGenerator {
model,
tokenizer,
var_store,
generate_config,
bos_token_id,
eos_token_ids,
pad_token_id,
is_encoder_decoder,
vocab_size,
decoder_start_id,
max_position_embeddings,
})
}
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
let _ = scores.index_fill_(
1,
&impossible_tokens,
get_negative_infinity(scores.kind()).unwrap(),
);
}
}
impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, PegasusTokenizer>
    for PegasusConditionalGenerator
{
    fn get_model(&self) -> &PegasusForConditionalGeneration {
        &self.model
    }
    fn _get_tokenizer(&self) -> &TokenizerOption {
        &self.tokenizer
    }
    fn get_var_store(&self) -> &nn::VarStore {
        &self.var_store
    }
    fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
        &mut self.var_store
    }
    fn get_config(&self) -> &GenerateConfig {
        &self.generate_config
    }
    fn get_bos_id(&self) -> &Option<i64> {
        &self.bos_token_id
    }
    fn get_eos_ids(&self) -> &Option<Vec<i64>> {
        &self.eos_token_ids
    }
    fn get_pad_id(&self) -> &Option<i64> {
        &self.pad_token_id
    }
    fn is_encoder_decoder(&self) -> bool {
        self.is_encoder_decoder
    }
    fn get_vocab_size(&self) -> i64 {
        self.vocab_size
    }
    fn get_decoder_start_id(&self) -> Option<i64> {
        self.decoder_start_id
    }
    fn get_max_positions_embeddings(&self) -> i64 {
        self.max_position_embeddings
    }
    /// Forces EOS on the final generation step so sequences terminate.
    /// `eos_token_ids` is always `Some` for this generator (set in `new`).
    fn prepare_scores_for_generation(
        &self,
        scores: &mut Tensor,
        current_length: i64,
        max_length: i64,
        _forced_bos_token_id: Option<i64>,
    ) {
        if current_length == max_length - 1 {
            self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
        }
    }
    fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
        Some(self.get_model().encode(input_ids, attention_mask))
    }
    /// Prepares decoder inputs for one generation step. With a populated cache
    /// only the last generated token is fed to the decoder; without a cache the
    /// full sequence is used.
    fn prepare_inputs_for_generation<'a>(
        &self,
        input_ids: Tensor,
        encoder_outputs: Option<&'a Tensor>,
        past: Cache,
        attention_mask: Tensor,
    ) -> PreparedInput<'a> {
        match past {
            Cache::BARTCache(past) => PreparedInput {
                prepared_input: None,
                prepared_attention_mask: Some(attention_mask),
                prepared_encoder_output: encoder_outputs,
                prepared_decoder_input: Some(input_ids.narrow(1, -1, 1)),
                prepared_position_ids: None,
                prepared_past: Cache::BARTCache(past),
            },
            Cache::None => PreparedInput {
                prepared_input: None,
                prepared_attention_mask: Some(attention_mask),
                prepared_encoder_output: encoder_outputs,
                prepared_decoder_input: Some(input_ids),
                prepared_position_ids: None,
                prepared_past: Cache::BARTCache(None),
            },
            _ => panic!("Cache type incompatible with Pegasus"),
        }
    }
    /// Tokenizes and truncates the prompts, then right-pads every sequence to
    /// the longest one in the batch and stacks them into a single tensor.
    fn encode_prompt_text<S>(
        &self,
        prompt_text: &[S],
        max_len: i64,
        pad_token_id: Option<i64>,
    ) -> Tensor
    where
        S: AsRef<str> + Sync,
    {
        let tokens = self._get_tokenizer().encode_list(
            prompt_text,
            max_len as usize,
            &TruncationStrategy::LongestFirst,
            0,
        );
        let token_ids = tokens
            .into_iter()
            .map(|tokenized_input| tokenized_input.token_ids)
            .collect::<Vec<Vec<i64>>>();
        // Longest sequence in the batch determines the padded length.
        let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
        let pad_token = match pad_token_id {
            Some(value) => value,
            // Fall back to the tokenizer's own padding token.
            None => self
                ._get_tokenizer()
                .convert_tokens_to_ids(&[PegasusVocab::pad_value()])[0],
        };
        let token_ids = token_ids
            .into_iter()
            .map(|mut input| {
                let temp = vec![pad_token; max_len - input.len()];
                input.extend(temp);
                input
            })
            .map(|tokens| Tensor::of_slice(&tokens).to(self.get_var_store().device()))
            .collect::<Vec<Tensor>>();
        Tensor::stack(&token_ids, 0)
    }
    /// Reorders the encoder output and the cached decoder states to follow the
    /// beams that survived the latest beam-search step.
    fn reorder_cache(
        &self,
        past: &mut Cache,
        encoder_outputs: Option<Tensor>,
        beam_indices: &Tensor,
    ) -> Option<Tensor> {
        let encoder_outputs = encoder_outputs.map(|value| value.index_select(0, beam_indices));
        match past {
            Cache::BARTCache(Some(old_cache)) => {
                for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
                    if let Some(layer_state) = self_layer_state {
                        layer_state.reorder_cache(beam_indices);
                    }
                    if let Some(layer_state) = encoder_layer_state {
                        layer_state.reorder_cache(beam_indices);
                    }
                }
            }
            // An empty cache or no cache needs no reordering.
            Cache::BARTCache(None) | Cache::None => {}
            _ => {
                panic!("Invalid cache for Pegasus model");
            }
        };
        encoder_outputs
    }
}
// Marker impl: all generation behavior comes from the `LanguageGenerator`
// default methods built on the `PrivateLanguageGenerator` impl above.
impl LanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, PegasusTokenizer>
    for PegasusConditionalGenerator
{
}
/// Pegasus shares the BART forward-pass output container.
pub type PegasusModelOutput = BartModelOutput;