use self::ordered_float::OrderedFloat;
use crate::bart::{
BartConfig, BartConfigResources, BartForConditionalGeneration, BartMergesResources,
BartModelResources, BartVocabResources, LayerState,
};
use crate::common::resources::{download_resource, RemoteResource, Resource};
use crate::gpt2::{
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
Gpt2VocabResources,
};
use crate::marian::MarianForConditionalGeneration;
use crate::openai_gpt::{
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources,
OpenAiGptModelResources, OpenAiGptVocabResources,
};
use crate::pipelines::generation::private_generation_utils::PrivateLanguageGenerator;
use crate::Config;
use itertools::Itertools;
use rust_tokenizers::preprocessing::tokenizer::marian_tokenizer::MarianTokenizer;
use rust_tokenizers::preprocessing::vocab::marian_vocab::MarianVocab;
use rust_tokenizers::{
Gpt2Tokenizer, Gpt2Vocab, OpenAiGptTokenizer, OpenAiGptVocab, RobertaTokenizer, RobertaVocab,
Tokenizer, TruncationStrategy, Vocab,
};
use tch::kind::Kind::Int64;
use tch::{nn, no_grad, Device, Tensor};
extern crate ordered_float;
/// # Configuration for text generation
/// Holds the resources to load (model weights, model configuration, vocabulary, merges) and the
/// decoding settings shared by the generators defined in this file. The `Default`
/// implementation points all resources at the pretrained GPT2 files.
pub struct GenerateConfig {
    /// Model weights resource (loaded into the generator's `VarStore`)
    pub model_resource: Resource,
    /// Model configuration resource
    pub config_resource: Resource,
    /// Tokenizer vocabulary resource
    pub vocab_resource: Resource,
    /// Tokenizer merges resource (`MarianGenerator` reads this slot as its SentencePiece model)
    pub merges_resource: Resource,
    /// Minimum length of the generated sequences
    pub min_length: u64,
    /// Maximum length of the generated sequences
    pub max_length: u64,
    /// Sample the next token (`true`) or select it deterministically (`false`)
    pub do_sample: bool,
    /// Early stopping flag for beam search (presumably forwarded to the beam hypotheses — TODO confirm)
    pub early_stopping: bool,
    /// Number of beams; with `do_sample == false`, 1 means greedy decoding (see `validate`)
    pub num_beams: u64,
    /// Sampling temperature, must be strictly positive (see `validate`)
    pub temperature: f64,
    /// Top-k filtering parameter (top-k filtering is skipped for values <= 0)
    pub top_k: u64,
    /// Nucleus (top-p) filtering parameter, must lie in [0, 1] (see `validate`)
    pub top_p: f64,
    /// Repetition penalty, must be >= 1; the penalty is only enforced when > 1
    pub repetition_penalty: f64,
    /// Length penalty used by beam search, must be strictly positive (see `validate`)
    pub length_penalty: f64,
    /// Size of the n-grams that may not be repeated (0 disables the constraint)
    pub no_repeat_ngram_size: u64,
    /// Number of sequences to return per input, must be strictly positive (see `validate`)
    pub num_return_sequences: u64,
    /// Device on which the model is loaded
    pub device: Device,
}
impl Default for GenerateConfig {
    /// Default generation settings: pretrained GPT2 resources, sampling with nucleus filtering
    /// (`top_p = 0.9`), 5 beams, `max_length = 20`, no repetition of 3-grams, and the model
    /// placed on CUDA when available.
    fn default() -> GenerateConfig {
        let model_resource =
            Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
        let config_resource =
            Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
        let vocab_resource =
            Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
        let merges_resource =
            Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
        GenerateConfig {
            model_resource,
            config_resource,
            vocab_resource,
            merges_resource,
            min_length: 0,
            max_length: 20,
            do_sample: true,
            early_stopping: false,
            num_beams: 5,
            temperature: 1.0,
            top_k: 0,
            top_p: 0.9,
            repetition_penalty: 1.0,
            length_penalty: 1.0,
            no_repeat_ngram_size: 3,
            num_return_sequences: 1,
            device: Device::cuda_if_available(),
        }
    }
}
impl GenerateConfig {
    /// Checks that the decoding parameters are consistent with each other.
    ///
    /// # Panics
    /// Panics if any of the following does not hold:
    /// - `temperature > 0`
    /// - `0 <= top_p <= 1`
    /// - `repetition_penalty >= 1`
    /// - `length_penalty > 0`
    /// - `num_return_sequences > 0` and `num_beams > 0`
    /// - without sampling: greedy decoding (`num_beams == 1`) returns exactly one sequence,
    ///   and beam search returns at most `num_beams` sequences.
    fn validate(&self) {
        assert!(
            self.temperature > 0f64,
            "temperature must be strictly positive"
        );
        assert!(
            (0f64..=1f64).contains(&self.top_p),
            "top_p must be between 0 and 1"
        );
        assert!(
            self.repetition_penalty >= 1f64,
            "repetition_penalty must be greater than or equal to 1"
        );
        assert!(
            self.length_penalty > 0f64,
            "length_penalty must be strictly greater than 0"
        );
        assert!(
            self.num_return_sequences > 0u64,
            "num_return_sequences must be strictly greater than 0"
        );
        assert!(
            self.num_beams > 0u64,
            "num_beams must be strictly greater than 0"
        );
        if !self.do_sample {
            if self.num_beams == 1 {
                assert_eq!(
                    self.num_return_sequences, 1,
                    "num_return_sequences must be set to 1 for greedy decoding"
                )
            } else {
                assert!(
                    self.num_beams >= self.num_return_sequences,
                    "num_return_sequences must be lower than or equal to the number of beams"
                )
            }
        }
    }
}
/// # Language generator for OpenAI GPT
/// Wraps an `OpenAIGPTLMHeadModel` and its tokenizer. Generation methods come from the
/// `LanguageGenerator` / `PrivateLanguageGenerator` traits.
pub struct OpenAIGenerator {
    /// Language model with LM head
    model: OpenAIGPTLMHeadModel,
    /// Tokenizer used to encode prompts and decode outputs
    tokenizer: OpenAiGptTokenizer,
    /// Variable store holding the model weights
    var_store: nn::VarStore,
    /// Generation settings (validated on construction)
    generate_config: GenerateConfig,
    /// Beginning-of-sequence token id (`None` for this model)
    bos_token_id: Option<i64>,
    /// End-of-sequence token ids (`None` for this model)
    eos_token_ids: Option<Vec<i64>>,
    /// Padding token id (`None` for this model)
    pad_token_id: Option<i64>,
    /// Always `false`: decoder-only model
    is_encoder_decoder: bool,
    /// Vocabulary size read from the model configuration
    vocab_size: i64,
    /// Decoder start token id (`None` for this model)
    decoder_start_id: Option<i64>,
}
impl OpenAIGenerator {
pub fn new(generate_config: GenerateConfig) -> failure::Fallible<OpenAIGenerator> {
generate_config.validate();
let model_resource = if &generate_config.model_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
))
} else {
generate_config.model_resource.clone()
};
let config_resource = if &generate_config.config_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if &generate_config.vocab_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
))
} else {
generate_config.vocab_resource.clone()
};
let merges_resource = if &generate_config.merges_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
))
} else {
generate_config.merges_resource.clone()
};
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&model_resource)?;
let device = generate_config.device;
let mut var_store = nn::VarStore::new(device);
let tokenizer = OpenAiGptTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let config = Gpt2Config::from_file(config_path);
let model = OpenAIGPTLMHeadModel::new(&var_store.root(), &config);
var_store.load(weights_path)?;
let bos_token_id = None;
let eos_token_ids = None;
let pad_token_id = None;
let is_encoder_decoder = false;
let vocab_size = config.vocab_size;
let decoder_start_id = None;
Ok(OpenAIGenerator {
model,
tokenizer,
var_store,
generate_config,
bos_token_id,
eos_token_ids,
pad_token_id,
is_encoder_decoder,
vocab_size,
decoder_start_id,
})
}
}
/// Accessors exposing the OpenAI GPT components to the shared generation routines.
/// Every other hook keeps the trait's default implementation.
impl PrivateLanguageGenerator<OpenAIGPTLMHeadModel, OpenAiGptVocab, OpenAiGptTokenizer>
    for OpenAIGenerator
{
    fn get_config(&self) -> &GenerateConfig {
        &self.generate_config
    }
    fn get_model(&self) -> &OpenAIGPTLMHeadModel {
        &self.model
    }
    fn get_var_store(&self) -> &nn::VarStore {
        &self.var_store
    }
    fn get_tokenizer(&self) -> &OpenAiGptTokenizer {
        &self.tokenizer
    }
    fn get_vocab_size(&self) -> i64 {
        self.vocab_size
    }
    fn is_encoder_decoder(&self) -> bool {
        self.is_encoder_decoder
    }
    fn get_bos_id(&self) -> &Option<i64> {
        &self.bos_token_id
    }
    fn get_pad_id(&self) -> &Option<i64> {
        &self.pad_token_id
    }
    fn get_eos_ids(&self) -> &Option<Vec<i64>> {
        &self.eos_token_ids
    }
    fn get_decoder_start_id(&self) -> Option<i64> {
        self.decoder_start_id
    }
}
/// Public generation API for OpenAI GPT; all methods use the trait's default implementations.
impl LanguageGenerator<OpenAIGPTLMHeadModel, OpenAiGptVocab, OpenAiGptTokenizer>
    for OpenAIGenerator {}
/// # Language generator for GPT2
/// Wraps a `GPT2LMHeadModel` and its tokenizer. Generation methods come from the
/// `LanguageGenerator` / `PrivateLanguageGenerator` traits.
pub struct GPT2Generator {
    /// Language model with LM head
    model: GPT2LMHeadModel,
    /// Tokenizer used to encode prompts and decode outputs
    tokenizer: Gpt2Tokenizer,
    /// Variable store holding the model weights
    var_store: nn::VarStore,
    /// Generation settings (validated on construction)
    generate_config: GenerateConfig,
    /// Beginning-of-sequence token id, looked up from the GPT2 vocabulary
    bos_token_id: Option<i64>,
    /// End-of-sequence token ids, looked up from the GPT2 vocabulary
    eos_token_ids: Option<Vec<i64>>,
    /// Padding token id (`None` for this model)
    pad_token_id: Option<i64>,
    /// Always `false`: decoder-only model
    is_encoder_decoder: bool,
    /// Vocabulary size read from the model configuration
    vocab_size: i64,
    /// Decoder start token id (`None` for this model)
    decoder_start_id: Option<i64>,
}
impl GPT2Generator {
    /// Builds a GPT2 generator from a generation configuration.
    ///
    /// # Errors
    /// Returns an error if a resource cannot be downloaded or the weights cannot be loaded.
    ///
    /// # Panics
    /// Panics if `generate_config` is invalid (see `GenerateConfig::validate`) or if a
    /// downloaded path is not valid UTF-8.
    pub fn new(generate_config: GenerateConfig) -> failure::Fallible<GPT2Generator> {
        // Validate first so that an invalid configuration fails before any download
        // (consistent with `OpenAIGenerator::new`).
        generate_config.validate();
        let config_path = download_resource(&generate_config.config_resource)?;
        let vocab_path = download_resource(&generate_config.vocab_resource)?;
        let merges_path = download_resource(&generate_config.merges_resource)?;
        let weights_path = download_resource(&generate_config.model_resource)?;
        let device = generate_config.device;
        let mut var_store = nn::VarStore::new(device);
        let tokenizer = Gpt2Tokenizer::from_file(
            vocab_path.to_str().unwrap(),
            merges_path.to_str().unwrap(),
            false,
        );
        let config = Gpt2Config::from_file(config_path);
        let model = GPT2LMHeadModel::new(&var_store.root(), &config);
        var_store.load(weights_path)?;
        let bos_token_id = Some(tokenizer.vocab().token_to_id(Gpt2Vocab::bos_value()));
        let eos_token_ids = Some(vec![tokenizer.vocab().token_to_id(Gpt2Vocab::eos_value())]);
        let pad_token_id = None;
        let is_encoder_decoder = false;
        let vocab_size = config.vocab_size;
        let decoder_start_id = None;
        Ok(GPT2Generator {
            model,
            tokenizer,
            var_store,
            generate_config,
            bos_token_id,
            eos_token_ids,
            pad_token_id,
            is_encoder_decoder,
            vocab_size,
            decoder_start_id,
        })
    }
}
/// GPT2-specific hooks for the shared generation routines.
impl PrivateLanguageGenerator<GPT2LMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GPT2Generator {
    fn get_model(&self) -> &GPT2LMHeadModel {
        &self.model
    }
    fn get_tokenizer(&self) -> &Gpt2Tokenizer {
        &self.tokenizer
    }
    fn get_var_store(&self) -> &nn::VarStore {
        &self.var_store
    }
    fn get_config(&self) -> &GenerateConfig {
        &self.generate_config
    }
    fn get_bos_id(&self) -> &Option<i64> {
        &self.bos_token_id
    }
    fn get_eos_ids(&self) -> &Option<Vec<i64>> {
        &self.eos_token_ids
    }
    fn get_pad_id(&self) -> &Option<i64> {
        &self.pad_token_id
    }
    fn is_encoder_decoder(&self) -> bool {
        self.is_encoder_decoder
    }
    fn get_vocab_size(&self) -> i64 {
        self.vocab_size
    }
    fn get_decoder_start_id(&self) -> Option<i64> {
        self.decoder_start_id
    }
    /// Builds the model inputs for the next step.
    ///
    /// When a populated GPT2 cache is available, only the last token of each sequence is
    /// fed to the model; otherwise the full `input_ids` are fed and an empty GPT2 cache is
    /// returned.
    ///
    /// # Panics
    /// Panics if `past` is a cache of another model type.
    fn prepare_inputs_for_generation<'a>(
        &self,
        input_ids: Tensor,
        _encoder_outputs: Option<&'a Tensor>,
        past: Cache,
        _attention_mask: Tensor,
    ) -> (Option<Tensor>, Option<&'a Tensor>, Option<Tensor>, Cache) {
        match past {
            Cache::GPT2Cache(Some(layer_past)) => {
                let last_tokens = input_ids.select(1, -1).unsqueeze(-1);
                (
                    Some(last_tokens),
                    None,
                    None,
                    Cache::GPT2Cache(Some(layer_past)),
                )
            }
            Cache::GPT2Cache(None) | Cache::None => {
                (Some(input_ids), None, None, Cache::GPT2Cache(None))
            }
            _ => panic!("Cache type incompatible with GPT2"),
        }
    }
}
/// Public generation API for GPT2; all methods use the trait's default implementations.
impl LanguageGenerator<GPT2LMHeadModel, Gpt2Vocab, Gpt2Tokenizer>
    for GPT2Generator
{
}
/// # Language generator for BART
/// Wraps a `BartForConditionalGeneration` model (encoder-decoder) and a Roberta tokenizer.
/// Generation methods come from the `LanguageGenerator` / `PrivateLanguageGenerator` traits.
pub struct BartGenerator {
    /// Encoder-decoder model with LM head
    model: BartForConditionalGeneration,
    /// Tokenizer used to encode prompts and decode outputs
    tokenizer: RobertaTokenizer,
    /// Variable store holding the model weights
    var_store: nn::VarStore,
    /// Generation settings (validated on construction)
    generate_config: GenerateConfig,
    /// Beginning-of-sequence token id (set to 0 by `new`)
    bos_token_id: Option<i64>,
    /// End-of-sequence token ids (from the config, 2 when unset)
    eos_token_ids: Option<Vec<i64>>,
    /// Padding token id (from the config, 1 when unset)
    pad_token_id: Option<i64>,
    /// Always `true`: encoder-decoder model
    is_encoder_decoder: bool,
    /// Vocabulary size read from the model configuration
    vocab_size: i64,
    /// Decoder start token id (set to 2 by `new`)
    decoder_start_id: Option<i64>,
}
impl BartGenerator {
pub fn new(generate_config: GenerateConfig) -> failure::Fallible<BartGenerator> {
let model_resource = if &generate_config.model_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART))
} else {
generate_config.model_resource.clone()
};
let config_resource = if &generate_config.config_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if &generate_config.vocab_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART))
} else {
generate_config.vocab_resource.clone()
};
let merges_resource = if &generate_config.merges_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART))
} else {
generate_config.merges_resource.clone()
};
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&model_resource)?;
let device = generate_config.device;
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let tokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
);
let config = BartConfig::from_file(config_path);
let model = BartForConditionalGeneration::new(&var_store.root(), &config, true);
var_store.load(weights_path)?;
let bos_token_id = Some(0);
let eos_token_ids = Some(match config.eos_token_id {
Some(value) => vec![value],
None => vec![2],
});
let pad_token_id = Some(match config.pad_token_id {
Some(value) => value,
None => 1,
});
let vocab_size = config.vocab_size;
let is_encoder_decoder = true;
let decoder_start_id = Some(2);
Ok(BartGenerator {
model,
tokenizer,
var_store,
generate_config,
bos_token_id,
eos_token_ids,
pad_token_id,
is_encoder_decoder,
vocab_size,
decoder_start_id,
})
}
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
let _ = scores.index_fill_(1, &impossible_tokens, std::f64::NEG_INFINITY);
}
}
/// BART-specific hooks for the shared generation routines.
impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, RobertaTokenizer>
    for BartGenerator
{
    fn get_model(&self) -> &BartForConditionalGeneration {
        &self.model
    }
    fn get_tokenizer(&self) -> &RobertaTokenizer {
        &self.tokenizer
    }
    fn get_var_store(&self) -> &nn::VarStore {
        &self.var_store
    }
    fn get_config(&self) -> &GenerateConfig {
        &self.generate_config
    }
    fn get_bos_id(&self) -> &Option<i64> {
        &self.bos_token_id
    }
    fn get_eos_ids(&self) -> &Option<Vec<i64>> {
        &self.eos_token_ids
    }
    fn get_pad_id(&self) -> &Option<i64> {
        &self.pad_token_id
    }
    fn is_encoder_decoder(&self) -> bool {
        self.is_encoder_decoder
    }
    fn get_vocab_size(&self) -> i64 {
        self.vocab_size
    }
    fn get_decoder_start_id(&self) -> Option<i64> {
        self.decoder_start_id
    }
    /// Forces the BOS token when `current_length == 1` and an EOS token when
    /// `current_length == max_length - 1`.
    ///
    /// # Panics
    /// Panics if the BOS id (resp. EOS ids) needed at that step is `None`.
    fn prepare_scores_for_generation(
        &self,
        scores: &mut Tensor,
        current_length: i64,
        max_length: i64,
    ) {
        if current_length == 1 {
            self.force_token_id_generation(scores, &[self.get_bos_id().unwrap()]);
        } else if current_length == max_length - 1 {
            self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
        }
    }
    /// Runs the BART encoder on the prompt.
    fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
        Some(self.get_model().encode(input_ids, attention_mask))
    }
    /// Feeds `input_ids` as decoder input together with the encoder outputs, and turns an
    /// empty cache into an empty BART cache.
    ///
    /// # Panics
    /// Panics if `past` is a cache of another model type.
    fn prepare_inputs_for_generation<'a>(
        &self,
        input_ids: Tensor,
        encoder_outputs: Option<&'a Tensor>,
        past: Cache,
        _attention_mask: Tensor,
    ) -> (Option<Tensor>, Option<&'a Tensor>, Option<Tensor>, Cache) {
        match past {
            Cache::BARTCache(past) => (
                None,
                encoder_outputs,
                Some(input_ids),
                Cache::BARTCache(past),
            ),
            Cache::None => (
                None,
                encoder_outputs,
                Some(input_ids),
                Cache::BARTCache(None),
            ),
            _ => panic!("Cache type incompatible with BART"),
        }
    }
    /// Encodes the prompts with truncation to `max_len` tokens and right-pads them to the
    /// longest prompt, using `pad_token_id` or, when `None`, the unknown token id.
    ///
    /// # Panics
    /// Panics if `prompt_text` is empty.
    fn encode_prompt_text(
        &self,
        prompt_text: Vec<&str>,
        max_len: u64,
        pad_token_id: Option<i64>,
    ) -> Tensor {
        let tokens = self.get_tokenizer().encode_list(
            prompt_text,
            max_len as usize,
            &TruncationStrategy::LongestFirst,
            0,
        );
        let token_ids = tokens
            .into_iter()
            .map(|tokenized_input| tokenized_input.token_ids)
            .collect::<Vec<Vec<i64>>>();
        let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
        let pad_token = pad_token_id.unwrap_or_else(|| {
            self.get_tokenizer()
                .vocab()
                .token_to_id(RobertaVocab::unknown_value())
        });
        let token_ids = token_ids
            .into_iter()
            .map(|mut input| {
                // Right padding: the encoder input starts at position 0 for every prompt
                input.resize(max_len, pad_token);
                input
            })
            .map(|tokens| Tensor::of_slice(&tokens).to(self.get_var_store().device()))
            .collect::<Vec<Tensor>>();
        Tensor::stack(&token_ids, 0)
    }
    /// Reorders the encoder outputs and the cached layer states along the batch dimension
    /// according to `beam_indices`, and returns the reordered encoder outputs.
    ///
    /// # Panics
    /// Panics if `past` is a cache of another model type.
    fn reorder_cache(
        &self,
        past: &mut Cache,
        encoder_outputs: Option<Tensor>,
        beam_indices: &Tensor,
    ) -> Option<Tensor> {
        let encoder_outputs = encoder_outputs.map(|value| value.index_select(0, beam_indices));
        match past {
            Cache::BARTCache(Some(old_cache)) => {
                // NOTE(review): relies on `LayerState::reorder_cache` updating the state in
                // place; no rebuilt cache is stored back into `past`.
                for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
                    if let Some(self_layer_state) = self_layer_state {
                        self_layer_state.reorder_cache(beam_indices);
                    }
                    if let Some(encoder_layer_state) = encoder_layer_state {
                        encoder_layer_state.reorder_cache(beam_indices);
                    }
                }
            }
            Cache::BARTCache(None) | Cache::None => {}
            _ => {
                panic!("Invalid cache for BART model");
            }
        };
        encoder_outputs
    }
}
/// Public generation API for BART; all methods use the trait's default implementations.
impl LanguageGenerator<BartForConditionalGeneration, RobertaVocab, RobertaTokenizer>
    for BartGenerator {}
/// # Language generator for Marian translation models
/// Wraps a `MarianForConditionalGeneration` model (encoder-decoder) and a Marian tokenizer.
/// Generation methods come from the `LanguageGenerator` / `PrivateLanguageGenerator` traits.
pub struct MarianGenerator {
    /// Encoder-decoder model with LM head
    model: MarianForConditionalGeneration,
    /// Tokenizer used to encode prompts and decode outputs
    tokenizer: MarianTokenizer,
    /// Variable store holding the model weights
    var_store: nn::VarStore,
    /// Generation settings (validated on construction)
    generate_config: GenerateConfig,
    /// Beginning-of-sequence token id (set to 0 by `new`)
    bos_token_id: Option<i64>,
    /// End-of-sequence token ids, looked up from the Marian vocabulary
    eos_token_ids: Option<Vec<i64>>,
    /// Padding token id, looked up from the Marian vocabulary
    pad_token_id: Option<i64>,
    /// Always `true`: encoder-decoder model
    is_encoder_decoder: bool,
    /// Vocabulary size read from the model configuration
    vocab_size: i64,
    /// Decoder start token id (the padding token id, as set by `new`)
    decoder_start_id: Option<i64>,
}
impl MarianGenerator {
    /// Builds a Marian generator from a generation configuration.
    ///
    /// `generate_config.merges_resource` is interpreted as the SentencePiece model file.
    ///
    /// # Errors
    /// Returns an error if a resource cannot be downloaded or the weights cannot be loaded.
    ///
    /// # Panics
    /// Panics if `generate_config` is invalid (see `GenerateConfig::validate`) or if a
    /// downloaded path is not valid UTF-8.
    pub fn new(generate_config: GenerateConfig) -> failure::Fallible<MarianGenerator> {
        // Validate first so that an invalid configuration fails before any download
        // (consistent with `OpenAIGenerator::new`).
        generate_config.validate();
        let config_path = download_resource(&generate_config.config_resource)?;
        let vocab_path = download_resource(&generate_config.vocab_resource)?;
        let sentence_piece_path = download_resource(&generate_config.merges_resource)?;
        let weights_path = download_resource(&generate_config.model_resource)?;
        let device = generate_config.device;
        let mut var_store = nn::VarStore::new(device);
        let tokenizer = MarianTokenizer::from_files(
            vocab_path.to_str().unwrap(),
            sentence_piece_path.to_str().unwrap(),
            false,
        );
        let config = BartConfig::from_file(config_path);
        let model = MarianForConditionalGeneration::new(&var_store.root(), &config, true);
        var_store.load(weights_path)?;
        let pad_id = tokenizer.vocab().token_to_id(MarianVocab::pad_value());
        let bos_token_id = Some(0);
        let eos_token_ids = Some(vec![tokenizer
            .vocab()
            .token_to_id(MarianVocab::eos_value())]);
        let pad_token_id = Some(pad_id);
        let vocab_size = config.vocab_size;
        let is_encoder_decoder = true;
        // Marian decoding starts from the padding token
        let decoder_start_id = Some(pad_id);
        Ok(MarianGenerator {
            model,
            tokenizer,
            var_store,
            generate_config,
            bos_token_id,
            eos_token_ids,
            pad_token_id,
            is_encoder_decoder,
            vocab_size,
            decoder_start_id,
        })
    }

    /// Masks every vocabulary entry except `token_ids` with negative infinity in `scores`
    /// (along dimension 1), so that only the given tokens can be selected.
    fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
        let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
            .filter(|pos| !token_ids.contains(pos))
            .collect();
        let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
        let _ = scores.index_fill_(1, &impossible_tokens, std::f64::NEG_INFINITY);
    }
}
/// Marian-specific hooks for the shared generation routines.
impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, MarianTokenizer>
    for MarianGenerator
{
    fn get_model(&self) -> &MarianForConditionalGeneration {
        &self.model
    }
    fn get_tokenizer(&self) -> &MarianTokenizer {
        &self.tokenizer
    }
    fn get_var_store(&self) -> &nn::VarStore {
        &self.var_store
    }
    fn get_config(&self) -> &GenerateConfig {
        &self.generate_config
    }
    fn get_bos_id(&self) -> &Option<i64> {
        &self.bos_token_id
    }
    fn get_eos_ids(&self) -> &Option<Vec<i64>> {
        &self.eos_token_ids
    }
    fn get_pad_id(&self) -> &Option<i64> {
        &self.pad_token_id
    }
    fn is_encoder_decoder(&self) -> bool {
        self.is_encoder_decoder
    }
    fn get_vocab_size(&self) -> i64 {
        self.vocab_size
    }
    fn get_decoder_start_id(&self) -> Option<i64> {
        self.decoder_start_id
    }
    /// Prevents the padding token from ever being generated, and forces an EOS token when
    /// `current_length == max_length - 1`.
    ///
    /// # Panics
    /// Panics if the pad id (or, at the last step, the EOS ids) is `None`.
    fn prepare_scores_for_generation(
        &self,
        scores: &mut Tensor,
        current_length: i64,
        max_length: i64,
    ) {
        let _ = scores.index_fill_(
            1,
            &Tensor::of_slice(&[self.get_pad_id().unwrap()])
                .to_kind(Int64)
                .to_device(scores.device()),
            std::f64::NEG_INFINITY,
        );
        if current_length == max_length - 1 {
            self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
        }
    }
    /// Runs the Marian encoder on the prompt.
    fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
        Some(self.get_model().encode(input_ids, attention_mask))
    }
    /// Feeds `input_ids` as decoder input together with the encoder outputs, and turns an
    /// empty cache into an empty BART-style cache.
    ///
    /// # Panics
    /// Panics if `past` is a cache of another model type.
    fn prepare_inputs_for_generation<'a>(
        &self,
        input_ids: Tensor,
        encoder_outputs: Option<&'a Tensor>,
        past: Cache,
        _attention_mask: Tensor,
    ) -> (Option<Tensor>, Option<&'a Tensor>, Option<Tensor>, Cache) {
        match past {
            Cache::BARTCache(past) => (
                None,
                encoder_outputs,
                Some(input_ids),
                Cache::BARTCache(past),
            ),
            Cache::None => (
                None,
                encoder_outputs,
                Some(input_ids),
                Cache::BARTCache(None),
            ),
            _ => panic!("Cache type incompatible with Marian"),
        }
    }
    /// Encodes the prompts with truncation to `max_len` tokens and right-pads them to the
    /// longest prompt, using `pad_token_id` or, when `None`, the Marian unknown token id.
    ///
    /// # Panics
    /// Panics if `prompt_text` is empty.
    fn encode_prompt_text(
        &self,
        prompt_text: Vec<&str>,
        max_len: u64,
        pad_token_id: Option<i64>,
    ) -> Tensor {
        let tokens = self.get_tokenizer().encode_list(
            prompt_text,
            max_len as usize,
            &TruncationStrategy::LongestFirst,
            0,
        );
        let token_ids = tokens
            .into_iter()
            .map(|tokenized_input| tokenized_input.token_ids)
            .collect::<Vec<Vec<i64>>>();
        let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
        // The fallback must come from the Marian vocabulary, not the Roberta one
        let pad_token = pad_token_id.unwrap_or_else(|| {
            self.get_tokenizer()
                .vocab()
                .token_to_id(MarianVocab::unknown_value())
        });
        let token_ids = token_ids
            .into_iter()
            .map(|mut input| {
                // Right padding: the encoder input starts at position 0 for every prompt
                input.resize(max_len, pad_token);
                input
            })
            .map(|tokens| Tensor::of_slice(&tokens).to(self.get_var_store().device()))
            .collect::<Vec<Tensor>>();
        Tensor::stack(&token_ids, 0)
    }
    /// Reorders the encoder outputs and the cached layer states along the batch dimension
    /// according to `beam_indices`, and returns the reordered encoder outputs.
    ///
    /// # Panics
    /// Panics if `past` is a cache of another model type.
    fn reorder_cache(
        &self,
        past: &mut Cache,
        encoder_outputs: Option<Tensor>,
        beam_indices: &Tensor,
    ) -> Option<Tensor> {
        let encoder_outputs = encoder_outputs.map(|value| value.index_select(0, beam_indices));
        match past {
            Cache::BARTCache(Some(old_cache)) => {
                // NOTE(review): relies on `LayerState::reorder_cache` updating the state in
                // place; no rebuilt cache is stored back into `past`.
                for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
                    if let Some(self_layer_state) = self_layer_state {
                        self_layer_state.reorder_cache(beam_indices);
                    }
                    if let Some(encoder_layer_state) = encoder_layer_state {
                        encoder_layer_state.reorder_cache(beam_indices);
                    }
                }
            }
            Cache::BARTCache(None) | Cache::None => {}
            _ => {
                panic!("Invalid cache for Marian model");
            }
        };
        encoder_outputs
    }
}
/// Public generation API for Marian; all methods use the trait's default implementations.
impl LanguageGenerator<MarianForConditionalGeneration, MarianVocab, MarianTokenizer>
    for MarianGenerator {}
/// Model state carried from one generation step to the next.
/// The generation loops start from `Cache::None` and replace it with the cache returned by
/// the model's forward pass.
#[derive(Debug)]
pub enum Cache {
    /// Cache used by GPT2 (`None` until the model has produced one)
    GPT2Cache(Option<Vec<Tensor>>),
    /// Cache used by BART and Marian: pairs of (self-attention, encoder-attention) layer
    /// states — presumably one pair per decoder layer, TODO confirm
    BARTCache(Option<Vec<(Option<LayerState>, Option<LayerState>)>>),
    /// No cache available yet
    None,
}
pub(crate) mod private_generation_utils {
use super::ordered_float::OrderedFloat;
use crate::pipelines::generation::{BeamHypotheses, Cache, GenerateConfig, LMHeadModel};
use itertools::Itertools;
use rust_tokenizers::preprocessing::tokenizer::tokenization_utils::truncate_sequences;
use rust_tokenizers::{Tokenizer, TruncationStrategy, Vocab};
use std::cmp::{max, min};
use std::collections::HashMap;
use tch::kind::Kind::{Bool, Float, Int64};
use tch::{nn, Device, Tensor};
pub trait PrivateLanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>> {
/// Returns the language model used for generation.
fn get_model(&self) -> &T;
/// Returns the tokenizer used to encode prompts.
fn get_tokenizer(&self) -> &U;
/// Returns the variable store holding the weights; its device is used to place new tensors.
fn get_var_store(&self) -> &nn::VarStore;
/// Returns the generation configuration.
fn get_config(&self) -> &GenerateConfig;
/// Returns the beginning-of-sequence token id, if the model defines one.
fn get_bos_id(&self) -> &Option<i64>;
/// Returns the end-of-sequence token ids, if the model defines any.
fn get_eos_ids(&self) -> &Option<Vec<i64>>;
/// Returns the padding token id, if the model defines one.
fn get_pad_id(&self) -> &Option<i64>;
/// Returns `true` for encoder-decoder models (the attention mask is then not extended
/// during generation).
fn is_encoder_decoder(&self) -> bool;
/// Returns the size of the model vocabulary.
fn get_vocab_size(&self) -> i64;
/// Returns the decoder start token id, if the model defines one.
fn get_decoder_start_id(&self) -> Option<i64>;
/// Hook allowing a model to modify the next-token scores in place (e.g. to force or forbid
/// specific tokens at a given step). The default implementation leaves the scores untouched.
fn prepare_scores_for_generation(
    &self,
    _scores: &mut Tensor,
    _current_length: i64,
    _max_length: i64,
) {}
/// Runs the encoder of an encoder-decoder model. Decoder-only models have no encoder, so the
/// default implementation returns `None`.
fn encode(&self, _input_ids: &Tensor, _attention_mask: Option<&Tensor>) -> Option<Tensor> {
    Option::None
}
/// Builds the `(input, encoder_outputs, decoder_input, cache)` tuple for the next forward pass.
/// The default feeds the full `input_ids` as model input, with no encoder outputs and no
/// decoder input, and passes the cache through unchanged.
fn prepare_inputs_for_generation<'a>(
    &self,
    input_ids: Tensor,
    _encoder_outputs: Option<&'a Tensor>,
    past: Cache,
    _attention_mask: Tensor,
) -> (Option<Tensor>, Option<&'a Tensor>, Option<Tensor>, Cache) {
    let model_input = Some(input_ids);
    (model_input, None, None, past)
}
/// Tokenizes the prompts, truncates each one to at most `max_len` tokens, and left-pads them
/// to the length of the longest prompt so that all prompts end at the last position.
/// Padding uses `pad_token_id`, or the unknown token id when `None`.
///
/// # Panics
/// Panics if `prompt_text` is empty.
fn encode_prompt_text(
    &self,
    prompt_text: Vec<&str>,
    max_len: u64,
    pad_token_id: Option<i64>,
) -> Tensor {
    let tokenizer = self.get_tokenizer();
    let limit = max_len as usize;

    let ids_per_prompt: Vec<Vec<i64>> = tokenizer
        .tokenize_list(prompt_text)
        .iter()
        .map(|prompt_tokens| tokenizer.convert_tokens_to_ids(prompt_tokens))
        .collect();

    let truncated: Vec<Vec<i64>> = ids_per_prompt
        .into_iter()
        .map(|ids| {
            let excess = ids.len().saturating_sub(limit);
            truncate_sequences(
                ids,
                None,
                vec![],
                None,
                vec![],
                None,
                vec![],
                None,
                excess,
                &TruncationStrategy::LongestFirst,
                0,
            )
            .unwrap()
            .0
        })
        .collect();

    let longest = truncated.iter().map(Vec::len).max().unwrap();
    let pad_token =
        pad_token_id.unwrap_or_else(|| tokenizer.vocab().token_to_id(V::unknown_value()));

    let rows: Vec<Tensor> = truncated
        .into_iter()
        .map(|ids| {
            let mut padded = vec![pad_token; longest - ids.len()];
            padded.extend(ids);
            Tensor::of_slice(&padded).to(self.get_var_store().device())
        })
        .collect();
    Tensor::stack(&rows, 0)
}
/// Applies the repetition penalty (CTRL, Keskar et al. 2019) in place to `next_token_logits`.
///
/// For each of the `batch_size * num_beams` rows, every distinct token already present in
/// the matching row of `prev_output_tokens` has its score penalized once. Negative scores are
/// multiplied by `repetition_penalty` and non-negative ones divided by it, so the penalty
/// always lowers the score when `repetition_penalty > 1`.
fn enforce_repetition_penalty(
    &self,
    next_token_logits: &mut Tensor,
    batch_size: i64,
    num_beams: u64,
    prev_output_tokens: &Tensor,
    repetition_penalty: f64,
) {
    for i in 0..(batch_size * num_beams as i64) {
        let hypothesis_tokens = prev_output_tokens.get(i);
        let mut tokens: Vec<i64> = (0..hypothesis_tokens.size()[0])
            .map(|position| hypothesis_tokens.int64_value(&[position]))
            .collect();
        // A token generated several times must be penalized only once; otherwise the
        // penalty would compound as repetition_penalty^k.
        tokens.sort_unstable();
        tokens.dedup();
        for token in tokens {
            let score = next_token_logits.double_value(&[i, token]);
            let penalized_score = if score < 0f64 {
                score * repetition_penalty
            } else {
                score / repetition_penalty
            };
            let _ = next_token_logits.get(i).index_fill_(
                0,
                &Tensor::of_slice(&[token])
                    .to_kind(Int64)
                    .to_device(next_token_logits.device()),
                penalized_score,
            );
        }
    }
}
/// Computes, for each hypothesis (row of `input_ids`), the tokens that would complete an
/// n-gram of size `no_repeat_ngram_size` already present in that hypothesis.
///
/// The last `no_repeat_ngram_size - 1` tokens of a hypothesis form the query prefix. Every
/// token that followed the same prefix earlier in the hypothesis is banned.
///
/// Always returns one list per hypothesis. The lists are empty when the constraint is
/// disabled (`no_repeat_ngram_size <= 0`) or not enough tokens are available yet.
fn get_banned_tokens(
    &self,
    input_ids: &Tensor,
    no_repeat_ngram_size: i64,
    cur_len: i64,
) -> Vec<Vec<i64>> {
    let num_hypothesis = *input_ids.size().first().unwrap();
    if no_repeat_ngram_size <= 0 || cur_len + 1 < no_repeat_ngram_size {
        return vec![vec![]; num_hypothesis as usize];
    }
    let ngram_size = no_repeat_ngram_size as usize;
    let input_ids = input_ids.to(Device::Cpu);
    let mut banned_tokens: Vec<Vec<i64>> = Vec::with_capacity(num_hypothesis as usize);
    for hypothesis_index in 0..num_hypothesis {
        let hypothesis_input_ids = input_ids
            .get(hypothesis_index)
            .iter::<i64>()
            .unwrap()
            .collect::<Vec<i64>>();
        // Map each (n-1)-token prefix to the tokens that followed it
        let mut generated_ngrams: HashMap<&[i64], Vec<i64>> = HashMap::new();
        for ngram in hypothesis_input_ids.windows(ngram_size) {
            let (last, prefix) = ngram.split_last().unwrap();
            generated_ngrams.entry(prefix).or_default().push(*last);
        }
        let query = &hypothesis_input_ids[cur_len as usize + 1 - ngram_size..];
        banned_tokens.push(generated_ngrams.get(query).cloned().unwrap_or_default());
    }
    banned_tokens
}
/// Filters `logits` in place (one row per hypothesis) by setting filtered entries to
/// negative infinity.
///
/// - Top-k (`top_k > 0`): keeps the `max(top_k, min_tokens_to_keep)` highest scores of each
///   row, capped at the vocabulary size.
/// - Nucleus (`top_p < 1`): removes the tokens beyond the smallest prefix of the sorted
///   distribution whose cumulative probability reaches `top_p`. The highest-scoring token is
///   always kept.
fn top_k_top_p_filtering(
    &self,
    logits: &mut Tensor,
    top_k: i64,
    top_p: f64,
    min_tokens_to_keep: i64,
) {
    let vocab_size = *logits.size().last().unwrap();
    if top_k > 0 {
        // Number of lowest-scoring entries to mask in each row
        let num_to_remove = vocab_size - min(max(top_k, min_tokens_to_keep), vocab_size);
        let (_, lowest_indices) = logits.topk(num_to_remove, -1, false, false);
        let num_rows = *logits.size().first().unwrap();
        for row in 0..num_rows {
            let _ = logits.get(row).index_fill_(
                0,
                &lowest_indices.get(row),
                std::f64::NEG_INFINITY,
            );
        }
    }
    if top_p < 1f64 {
        let device = logits.device();
        let (sorted_logits, sorted_indices) = logits.sort(-1, true);
        let cumulative_probabilities = sorted_logits.softmax(-1, Float).cumsum(-1, Float);
        let mut removal_mask = cumulative_probabilities.ge(top_p).to_kind(Int64);
        if min_tokens_to_keep > 1 {
            let _ = removal_mask.index_fill_(
                1,
                &Tensor::arange1(0, min_tokens_to_keep + 1, (Int64, device)),
                0,
            );
        }
        // Shift the mask right by one position so the token crossing the threshold is kept
        let shifted_mask = removal_mask.slice(1, 0, vocab_size - 1, 1).copy();
        let _ = removal_mask.index_copy_(
            1,
            &Tensor::arange1(1, vocab_size, (Int64, device)),
            &shifted_mask,
        );
        let _ = removal_mask.index_fill_(
            1,
            &Tensor::of_slice(&[0])
                .to_kind(Int64)
                .to_device(removal_mask.device()),
            0,
        );
        // Map the mask from sorted order back to vocabulary order
        let indices_to_remove = removal_mask
            .scatter(1, &sorted_indices, &removal_mask)
            .to_kind(Bool);
        let _ = logits.masked_fill_(&indices_to_remove, std::f64::NEG_INFINITY);
    }
}
/// Generates sequences one token at a time without beam search. The next token is chosen
/// either greedily (argmax) or by sampling with temperature, top-k and nucleus filtering.
///
/// At each step the following constraints are applied:
/// - the repetition penalty (when `> 1`);
/// - the no-repeat n-gram constraint (when `no_repeat_ngram_size > 0`);
/// - EOS tokens are forbidden while `current_length < min_length`.
///
/// Generation stops at `max_length` or, when `eos_token_ids` is set, as soon as every
/// sequence has produced an EOS token. When `eos_token_ids` is set, finished sequences are
/// extended with `pad_token_id`, which must then be `Some`.
///
/// # Returns
/// The generated token ids, one row per input sequence. If sequences finished at different
/// lengths and `pad_token_id` is set, the output width is the longest sequence length and
/// each row is padded with `pad_token_id` after its own end.
fn generate_no_beam_search(
    &self,
    input_ids: Tensor,
    encoder_outputs: Option<Tensor>,
    cur_len: i64,
    min_length: i64,
    max_length: i64,
    do_sample: bool,
    temperature: f64,
    top_k: i64,
    top_p: f64,
    repetition_penalty: f64,
    no_repeat_ngram_size: i64,
    pad_token_id: Option<i64>,
    eos_token_ids: Option<Vec<i64>>,
    batch_size: i64,
    attention_mask: Tensor,
) -> Tensor {
    let device = self.get_var_store().device();
    let mut unfinished_sentences = Tensor::ones(&[batch_size], (Int64, device));
    let mut sentence_lengths: Tensor = Tensor::ones(&[batch_size], (Int64, device)) * max_length;
    let mut attention_mask = attention_mask.copy();
    let mut input_ids = input_ids.copy();
    let mut past: Cache = Cache::None;
    let mut current_length = cur_len;
    while current_length < max_length {
        let (prepared_input, prepared_encoder_output, prepared_decoder_input, prepared_past) = self
            .prepare_inputs_for_generation(
                input_ids.copy(),
                encoder_outputs.as_ref(),
                past,
                attention_mask.copy(),
            );
        let temp = self
            .get_model()
            .forward_t(
                &prepared_input,
                prepared_past,
                &None,
                &None,
                &None,
                &None,
                prepared_encoder_output,
                &prepared_decoder_input,
                false,
            )
            .unwrap();
        let outputs = temp.0;
        past = temp.2;
        let mut next_token_logits = outputs.select(1, -1);
        if repetition_penalty > 1f64 {
            self.enforce_repetition_penalty(
                &mut next_token_logits,
                batch_size,
                1,
                &input_ids,
                repetition_penalty,
            )
        }
        if no_repeat_ngram_size > 0 {
            let banned_tokens =
                self.get_banned_tokens(&input_ids, no_repeat_ngram_size, current_length);
            for (batch_index, hypothesis_banned_tokens) in banned_tokens.into_iter().enumerate() {
                let _ = next_token_logits.get(batch_index as i64).index_fill_(
                    0,
                    &Tensor::of_slice(&hypothesis_banned_tokens)
                        .to_device(next_token_logits.device()),
                    std::f64::NEG_INFINITY,
                );
            }
        }
        if current_length < min_length {
            if let Some(eos_ids) = &eos_token_ids {
                // Forbid end-of-sequence tokens until the minimum length is reached
                let _ = next_token_logits.index_fill_(
                    1,
                    &Tensor::of_slice(eos_ids).to(next_token_logits.device()),
                    std::f64::NEG_INFINITY,
                );
            }
        }
        let next_token = if do_sample {
            // Any temperature other than 1 must be applied: values below 1 sharpen the
            // distribution, values above 1 flatten it.
            if (temperature - 1f64).abs() > std::f64::EPSILON {
                next_token_logits = next_token_logits / temperature;
            }
            self.top_k_top_p_filtering(&mut next_token_logits, top_k, top_p, 1);
            let probabilities = next_token_logits.softmax(-1, Float);
            probabilities.multinomial(1, false).squeeze1(1)
        } else {
            next_token_logits.argmax(-1, false)
        };
        // Sentences that are already finished receive the padding token instead
        let tokens_to_add = match &eos_token_ids {
            Some(_) => {
                next_token * &unfinished_sentences
                    - pad_token_id.unwrap() * (&unfinished_sentences - 1)
            }
            None => next_token,
        };
        input_ids = Tensor::cat(&[input_ids, tokens_to_add.unsqueeze(-1)], -1);
        if let Some(eos_ids) = &eos_token_ids {
            for eos_token_id in eos_ids {
                let sentence_with_eos = tokens_to_add.eq(*eos_token_id).to_kind(Int64);
                let sentence_with_eos: Tensor = sentence_with_eos * &unfinished_sentences;
                let _ = sentence_lengths.masked_fill_(
                    &sentence_with_eos
                        .to_kind(Bool)
                        .to_device(sentence_lengths.device()),
                    current_length + 1,
                );
                unfinished_sentences = -unfinished_sentences * (sentence_with_eos - 1);
            }
            if i64::from(unfinished_sentences.max()) == 0 {
                break;
            }
        }
        if !self.is_encoder_decoder() {
            let new_mask_column = Tensor::ones(
                &[*attention_mask.size().first().unwrap(), 1],
                (Int64, attention_mask.device()),
            );
            attention_mask = Tensor::cat(
                &[attention_mask.as_ref(), new_mask_column.as_ref()],
                -1,
            );
        }
        current_length += 1;
    }
    let lengths_differ = i64::from(&sentence_lengths.min().ne1(&sentence_lengths.max())) > 0;
    match (lengths_differ, pad_token_id) {
        (true, Some(pad_value)) => {
            let decoded: Tensor = Tensor::ones(
                &[batch_size, i64::from(sentence_lengths.max())],
                (Int64, input_ids.device()),
            ) * pad_value;
            for hypothesis_index in 0..*input_ids.size().first().unwrap() {
                let sentence_length = i64::from(sentence_lengths.get(hypothesis_index));
                let _ = decoded.get(hypothesis_index).index_copy_(
                    0,
                    &Tensor::arange1(0, sentence_length, (Int64, input_ids.device())),
                    &input_ids
                        .get(hypothesis_index)
                        .slice(0, 0, sentence_length, 1),
                );
            }
            decoded
        }
        _ => input_ids,
    }
}
fn generate_beam_search(
&self,
input_ids: Tensor,
encoder_outputs: Option<Tensor>,
cur_len: i64,
min_length: i64,
max_length: i64,
do_sample: bool,
early_stopping: bool,
temperature: f64,
top_k: i64,
top_p: f64,
repetition_penalty: f64,
no_repeat_ngram_size: i64,
pad_token_id: Option<i64>,
eos_token_ids: Option<Vec<i64>>,
batch_size: i64,
num_return_sequences: i64,
length_penalty: f64,
num_beams: i64,
attention_mask: Tensor,
) -> Tensor {
let mut hypotheses = (0..batch_size)
.map(|_| BeamHypotheses::new(num_beams, max_length, length_penalty, early_stopping))
.collect::<Vec<BeamHypotheses>>();
let vocab_size = self.get_vocab_size();
let beam_scores = Tensor::zeros(
&[batch_size, num_beams],
(Float, self.get_var_store().device()),
);
if !do_sample {
let _ = beam_scores
.slice(1, 1, *beam_scores.size().last().unwrap(), 1)
.fill_(-1e9);
}
let mut beam_scores = beam_scores.view_(&[-1]);
let mut beam_tokens: Tensor;
let mut beam_indices: Tensor;
let mut past: Cache = Cache::None;
let mut done = vec![false; batch_size as usize];
let mut attention_mask = attention_mask.copy();
let mut input_ids = input_ids.copy();
let mut outputs: Tensor;
let mut encoder_outputs = encoder_outputs;
let mut current_length = cur_len;
while current_length < max_length {
let (
prepared_input,
prepared_encoder_output,
prepared_decoder_input,
prepared_past,
) = self.prepare_inputs_for_generation(
input_ids.copy(),
encoder_outputs.as_ref(),
past,
attention_mask.copy(),
);
let temp = self
.get_model()
.forward_t(
&prepared_input,
prepared_past,
&None,
&None,
&None,
&None,
prepared_encoder_output,
&prepared_decoder_input,
false,
)
.unwrap();
outputs = temp.0;
past = temp.2;
let mut next_token_logits = outputs.select(1, -1);
if repetition_penalty > 1f64 {
self.enforce_repetition_penalty(
&mut next_token_logits,
batch_size,
1,
&input_ids,
repetition_penalty,
)
}
if temperature > 1f64 {
next_token_logits = next_token_logits / temperature;
}
let mut scores = next_token_logits.log_softmax(-1, Float);
if self.is_encoder_decoder() & !do_sample {
self.prepare_scores_for_generation(&mut scores, current_length, max_length);
}
if (&eos_token_ids.is_some()) & (current_length < min_length) {
&scores.index_fill_(
1,
&Tensor::of_slice(eos_token_ids.as_ref().unwrap()).to(scores.device()),
std::f64::NEG_INFINITY,
);
}
if no_repeat_ngram_size > 0 {
let banned_tokens = self.get_banned_tokens(
&input_ids,
no_repeat_ngram_size as i64,
current_length as i64,
);
for (batch_index, index_banned_token) in
(0..banned_tokens.len() as i64).zip(banned_tokens)
{
&scores.get(batch_index).index_fill_(
0,
&Tensor::of_slice(&index_banned_token)
.to_device(next_token_logits.device()),
std::f64::NEG_INFINITY,
);
}
}
let (next_scores, next_tokens) = if do_sample {
let mut _scores: Tensor =
&scores + &beam_scores.unsqueeze(-1).expand_as(&scores);
self.top_k_top_p_filtering(&mut _scores, top_k as i64, top_p, 2);
let _scores = _scores
.contiguous()
.view((batch_size, num_beams * vocab_size));
let probabilities = _scores.softmax(-1, Float);
let next_tokens = probabilities.multinomial(2 * num_beams, false);
let next_scores = _scores.gather(-1, &next_tokens, false);
let (next_scores, next_scores_indices) = next_scores.sort(1, true);
let next_tokens = next_tokens.gather(-1, &next_scores_indices, false);
(next_scores, next_tokens)
} else {
let next_scores: Tensor =
&scores + &beam_scores.unsqueeze(-1).expand_as(&scores);
let next_scores = next_scores
.contiguous()
.view((batch_size, num_beams * vocab_size));
next_scores.topk(2 * num_beams, 1, true, true)
};
let mut next_batch_beam: Vec<(f64, i64, i64)> = vec![];
for batch_index in 0..batch_size {
if done[batch_index as usize] {
assert!(
hypotheses[batch_index as usize].len() >= num_beams,
"Batch cannot be completed if all beams have not been generated"
);
assert!(
eos_token_ids.is_some() & pad_token_id.is_some(),
"EOS and Padding tokens need to be defined if the number of generated \
beams is greater than the target number fo beams"
);
next_batch_beam.append(
&mut (0..num_beams)
.map(|_| (0f64, pad_token_id.unwrap(), 0i64))
.collect::<Vec<(f64, i64, i64)>>(),
);
continue;
}
let mut next_sentence_beam: Vec<(f64, i64, i64)> = vec![];
let mut beam_token_rank = 0;
let beam_token_rank_max_value =
*next_tokens.get(batch_index).size().first().unwrap() - 1;
loop {
let beam_token_id =
next_tokens.int64_value(&[batch_index, beam_token_rank]);
let beam_token_score =
next_scores.double_value(&[batch_index, beam_token_rank]);
let beam_id = beam_token_id / vocab_size;
let token_id = beam_token_id % vocab_size;
let effective_beam_id = batch_index * num_beams + beam_id;
if eos_token_ids.as_ref().is_some() {
if eos_token_ids.as_ref().unwrap().contains(&token_id) {
if beam_token_rank >= num_beams {
beam_token_rank += 1;
continue;
}
hypotheses[batch_index as usize]
.add(input_ids.get(effective_beam_id).copy(), beam_token_score)
} else {
next_sentence_beam.push((
beam_token_score,
token_id,
effective_beam_id,
));
}
} else {
next_sentence_beam.push((
beam_token_score,
token_id,
effective_beam_id,
));
}
if (next_sentence_beam.len() as i64 == num_beams)
| (beam_token_rank == beam_token_rank_max_value)
{
break;
}
beam_token_rank += 1;
}
done[batch_index as usize] = done[batch_index as usize]
| hypotheses[batch_index as usize].is_done(
f64::from(next_scores.get(batch_index).max()),
current_length,
);
assert_eq!(
next_sentence_beam.len() as i64,
num_beams,
"Beam incomplete"
);
next_batch_beam.append(&mut next_sentence_beam);
}
if done.iter().all(|&x| x) {
break;
}
beam_scores = Tensor::of_slice(
&next_batch_beam
.iter()
.map(|(score, _, _)| *score)
.collect_vec(),
)
.to(input_ids.device());
beam_tokens = Tensor::of_slice(
&next_batch_beam
.iter()
.map(|(_, token, _)| *token)
.collect_vec(),
)
.to(input_ids.device());
beam_indices = Tensor::of_slice(
&next_batch_beam
.iter()
.map(|(_, _, index)| *index)
.collect_vec(),
)
.to(input_ids.device());
input_ids = input_ids.index_select(0, &beam_indices);
input_ids = Tensor::cat(&[input_ids, beam_tokens.unsqueeze(1)], -1);
encoder_outputs = self.reorder_cache(&mut past, encoder_outputs, &beam_indices);
if !self.is_encoder_decoder() {
attention_mask = Tensor::cat(
&[
attention_mask.as_ref(),
Tensor::ones(
&[*attention_mask.size().first().unwrap(), 1],
(Int64, attention_mask.device()),
)
.as_ref(),
],
-1,
);
}
current_length += 1;
}
let mut batch_index = 0i64;
loop {
if batch_index == batch_size {
break;
}
if done[batch_index as usize] {
batch_index += 1;
continue;
}
for beam_index in 0..num_beams {
let effective_beam_id = batch_index * num_beams + beam_index;
let final_score = f64::from(beam_scores.get(effective_beam_id));
let final_tokens = input_ids.get(effective_beam_id);
hypotheses[batch_index as usize].add(final_tokens, final_score);
}
batch_index += 1;
}
let (output_batch_size, output_num_return_sequences_per_batch) = if do_sample {
(batch_size, 1)
} else {
(batch_size * num_return_sequences, num_return_sequences)
};
let mut sentence_lengths =
Tensor::zeros(&[output_batch_size], (Int64, input_ids.device()));
let mut best_ids = vec![];
for (hypothesis_index, hypothesis) in hypotheses.iter().enumerate() {
let mut sorted_hypotheses = hypothesis.clone();
&sorted_hypotheses
.beams
.sort_by_key(|(score, _)| OrderedFloat(*score));
for j in 0..output_num_return_sequences_per_batch {
let effective_batch_index =
output_num_return_sequences_per_batch * hypothesis_index as i64 + j;
let (_, best_hyp) = sorted_hypotheses.beams.pop().unwrap();
let _ = sentence_lengths.index_fill_(
0,
&Tensor::of_slice(&[effective_batch_index]).to(sentence_lengths.device()),
*best_hyp.size().first().unwrap(),
);
best_ids.push(best_hyp);
}
}
let decoded = if i64::from(sentence_lengths.max()) != i64::from(sentence_lengths.min())
{
let sentence_max_length = min(i64::from(sentence_lengths.max()) + 1, max_length);
let decoded: Tensor = Tensor::ones(
&[output_batch_size, sentence_max_length],
(Int64, input_ids.device()),
) * pad_token_id.unwrap();
for hypothesis_index in 0..best_ids.len() {
let _ = decoded.get(hypothesis_index as i64).index_copy_(
0,
&Tensor::arange1(
0,
i64::from(sentence_lengths.get(hypothesis_index as i64)),
(Int64, input_ids.device()),
),
&best_ids[hypothesis_index],
);
let sentence_length = i64::from(sentence_lengths.get(hypothesis_index as i64));
if sentence_length < max_length {
let _ = decoded.get(hypothesis_index as i64).index_fill_(
0,
&Tensor::of_slice(&[sentence_length]).to_device(input_ids.device()),
eos_token_ids.as_ref().unwrap()[0],
);
}
}
decoded
} else {
Tensor::stack(&best_ids, 0)
.to_kind(Int64)
.to(input_ids.device())
};
decoded
}
/// Re-indexes the cached decoder state so it follows the beams selected at the current step.
///
/// For `Cache::GPT2Cache`, every cached layer tensor (when present) is re-indexed along
/// dimension 1 with `beam_indices`. The encoder outputs are ignored by this implementation.
/// `None` is returned for both `Cache::None` and `Cache::GPT2Cache`.
///
/// # Panics
/// Panics with "Not implemented" for `Cache::BARTCache`.
fn reorder_cache(
    &self,
    past: &mut Cache,
    _encoder_outputs: Option<Tensor>,
    beam_indices: &Tensor,
) -> Option<Tensor> {
    match past {
        Cache::GPT2Cache(cached_decoder_state) => {
            if let Some(layer_pasts) = cached_decoder_state {
                layer_pasts.iter_mut().for_each(|layer_past| {
                    *layer_past = layer_past.index_select(1, beam_indices);
                });
            }
            None
        }
        Cache::BARTCache(_) => panic!("Not implemented"),
        Cache::None => None,
    }
}
}
}
/// Public text-generation interface built on top of `PrivateLanguageGenerator`.
///
/// Generation settings are read from the generator's configuration (`get_config`).
/// Beam search is used when `num_beams > 1`; otherwise greedy or sampled decoding is used.
pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
    PrivateLanguageGenerator<T, V, U>
{
    /// Generates text, optionally continuing from the given prompts.
    ///
    /// # Arguments
    /// * `prompt_texts` - prompts to encode and continue. When `None`, generation starts from a
    ///   single BOS token.
    /// * `attention_mask` - optional attention mask for the encoded prompts. When `None`, one is
    ///   derived by `generate_from_ids_and_past`.
    ///
    /// # Returns
    /// One decoded `String` per generated sequence.
    ///
    /// # Panics
    /// Panics if `prompt_texts` is `None` and the model has no BOS token.
    fn generate(
        &self,
        prompt_texts: Option<Vec<&str>>,
        attention_mask: Option<Tensor>,
    ) -> Vec<String> {
        let config = PrivateLanguageGenerator::get_config(self);
        // Encoder-decoder prompts may use up to 1024 tokens. Decoder-only prompts are limited to
        // `max_length` because their length counts towards the generated length.
        let encoding_max_len = if self.is_encoder_decoder() {
            1024u64
        } else {
            config.max_length
        };
        // Fall back to the first EOS id when the model has no padding token.
        let pad_token_id = match self.get_pad_id() {
            Some(value) => Some(*value),
            None => match PrivateLanguageGenerator::get_eos_ids(self) {
                Some(eos_ids) => Some(eos_ids[0]),
                None => None,
            },
        };
        let input_ids = match prompt_texts {
            Some(text) => self.encode_prompt_text(text, encoding_max_len, pad_token_id),
            None => match self.get_bos_id() {
                Some(bos_id) => {
                    Tensor::ones(&[1, 1], (Int64, self.get_var_store().device())) * *bos_id
                }
                None => panic!(
                    "A model with a BOS token must be used to start generation with an empty input"
                ),
            },
        };
        self.generate_from_ids_and_past(input_ids, attention_mask)
            .into_iter()
            .map(|generated_sequence| self.get_tokenizer().decode(generated_sequence, true, true))
            .collect()
    }

    /// Generates token ids from already-encoded inputs.
    ///
    /// For encoder-decoder models, the inputs are encoded once and the encoder output is
    /// repeated for every beam and returned sequence. Decoding then starts from the decoder
    /// start token.
    ///
    /// For decoder-only models, the inputs are repeated the same way and decoding continues
    /// them.
    ///
    /// # Arguments
    /// * `input_ids` - encoded inputs; dimension 0 is read as the batch size and, for
    ///   decoder-only models, the last dimension as the current length
    /// * `attention_mask` - optional mask for `input_ids`. When `None`, it marks the positions
    ///   that are not the padding token, or every position if the model has no padding token.
    ///
    /// # Returns
    /// One `Vec<i64>` of token ids per generated sequence.
    ///
    /// # Panics
    /// Panics if an encoder-decoder model has no decoder start id.
    fn generate_from_ids_and_past(
        &self,
        input_ids: Tensor,
        attention_mask: Option<Tensor>,
    ) -> Vec<Vec<i64>> {
        let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
        let config = PrivateLanguageGenerator::get_config(self);
        let do_sample = config.do_sample;
        let num_return_sequences = config.num_return_sequences;
        let num_beams = config.num_beams;
        let min_length = config.min_length;
        let max_length = config.max_length;
        let early_stopping = config.early_stopping;
        let temperature = config.temperature;
        let top_k = config.top_k;
        let top_p = config.top_p;
        let repetition_penalty = config.repetition_penalty;
        let length_penalty = config.length_penalty;
        let no_repeat_ngram_size = config.no_repeat_ngram_size;
        // Fall back to the first EOS id when the model has no padding token.
        let pad_token_id = match self.get_pad_id() {
            Some(value) => Some(*value),
            None => match &eos_token_ids {
                Some(eos_ids) => Some(eos_ids[0]),
                None => None,
            },
        };
        // The decoder of an encoder-decoder model starts from a single start token.
        let cur_len = if !self.is_encoder_decoder() {
            *input_ids.size().last().unwrap()
        } else {
            1
        };
        let batch_size = *input_ids.size().first().unwrap();
        // When sampling, each input is repeated `num_return_sequences` times.
        let (effective_batch_size, effective_batch_mult) = match do_sample {
            true => (
                batch_size * num_return_sequences as i64,
                num_return_sequences as i64,
            ),
            false => (batch_size, 1),
        };
        // The default mask uses the model's own padding id, not the EOS fallback.
        let attention_mask = match attention_mask {
            Some(value) => value,
            None => match self.get_pad_id() {
                Some(pad_id) => input_ids.ne(*pad_id).to_kind(Int64),
                None => input_ids.ones_like(),
            },
        };
        let encoder_outputs = if self.is_encoder_decoder() {
            let encoder_outputs = self.encode(&input_ids, Some(&attention_mask)).unwrap();
            // Repeat every batch row once per beam and per returned sequence.
            let expanded_batch_indices = Tensor::arange(batch_size, (Int64, input_ids.device()))
                .view((-1, 1))
                .repeat(&[1, num_beams as i64 * effective_batch_mult])
                .view(-1);
            Some(encoder_outputs.index_select(0, &expanded_batch_indices))
        } else {
            None
        };
        let (input_ids, attention_mask) = if !self.is_encoder_decoder() {
            if num_return_sequences > 1 || num_beams > 1 {
                (
                    input_ids
                        .unsqueeze(1)
                        .expand(
                            &[batch_size, effective_batch_mult * num_beams as i64, cur_len],
                            true,
                        )
                        .contiguous()
                        .view((effective_batch_size * num_beams as i64, cur_len)),
                    attention_mask
                        .unsqueeze(1)
                        .expand(
                            &[batch_size, effective_batch_mult * num_beams as i64, cur_len],
                            true,
                        )
                        .contiguous()
                        .view((effective_batch_size * num_beams as i64, cur_len)),
                )
            } else {
                (input_ids, attention_mask)
            }
        } else {
            let decoder_start_token_id = self
                .get_decoder_start_id()
                .expect("decoder start id must be specified for encoder decoders");
            let input_ids = Tensor::full(
                &[effective_batch_size * num_beams as i64, 1],
                decoder_start_token_id,
                (Int64, input_ids.device()),
            );
            (input_ids, attention_mask)
        };
        let decoded = no_grad(|| {
            if num_beams > 1 {
                self.generate_beam_search(
                    input_ids,
                    encoder_outputs,
                    cur_len,
                    min_length as i64,
                    max_length as i64,
                    do_sample,
                    early_stopping,
                    temperature,
                    top_k as i64,
                    top_p,
                    repetition_penalty,
                    no_repeat_ngram_size as i64,
                    pad_token_id,
                    eos_token_ids,
                    effective_batch_size,
                    num_return_sequences as i64,
                    length_penalty,
                    num_beams as i64,
                    attention_mask,
                )
            } else {
                self.generate_no_beam_search(
                    input_ids,
                    encoder_outputs,
                    cur_len,
                    min_length as i64,
                    max_length as i64,
                    do_sample,
                    temperature,
                    top_k as i64,
                    top_p,
                    repetition_penalty,
                    no_repeat_ngram_size as i64,
                    pad_token_id,
                    eos_token_ids,
                    effective_batch_size,
                    attention_mask,
                )
            }
        });
        // One row of `decoded` per generated sequence.
        let num_sequences = *decoded.size().first().unwrap();
        (0..num_sequences)
            .map(|sequence_index| {
                decoded
                    .get(sequence_index)
                    .iter::<i64>()
                    .unwrap()
                    .collect::<Vec<i64>>()
            })
            .collect()
    }
}
/// Finished hypotheses collected for one batch entry during beam search.
///
/// At most `num_beams` hypotheses are kept, each stored with its length-normalised score.
#[derive(Debug)]
struct BeamHypotheses {
/// Set to `max_length - 1` by `BeamHypotheses::new`; not read by `add` or `is_done`.
max_length: i64,
/// Exponent applied to the hypothesis length when normalising scores.
length_penalty: f64,
/// When true, `is_done` reports completion as soon as `num_beams` hypotheses are stored.
early_stopping: bool,
/// Maximum number of hypotheses kept.
num_beams: i64,
/// `(normalised score, token ids)` pairs in insertion order (not sorted).
beams: Vec<(f64, Tensor)>,
/// Lowest score currently stored; `BeamHypotheses::new` initialises it to infinity.
worst_score: f64,
}
impl Clone for BeamHypotheses {
fn clone(&self) -> Self {
BeamHypotheses {
max_length: self.max_length,
length_penalty: self.length_penalty,
early_stopping: self.early_stopping,
num_beams: self.num_beams,
beams: self
.beams
.iter()
.map(|(score, tensor)| (*score, tensor.copy()))
.collect_vec(),
worst_score: self.worst_score,
}
}
}
impl BeamHypotheses {
    /// Creates an empty collection holding up to `num_beams` hypotheses.
    ///
    /// `max_length` is stored decremented by one. The worst score starts at positive infinity.
    fn new(
        num_beams: i64,
        max_length: i64,
        length_penalty: f64,
        early_stopping: bool,
    ) -> BeamHypotheses {
        // One extra slot: `add` pushes before evicting the worst entry.
        let capacity = num_beams as usize + 1;
        BeamHypotheses {
            beams: Vec::with_capacity(capacity),
            worst_score: std::f64::INFINITY,
            num_beams,
            early_stopping,
            length_penalty,
            max_length: max_length - 1,
        }
    }

    /// Number of hypotheses currently stored.
    fn len(&self) -> i64 {
        let count = self.beams.len();
        count as i64
    }

    /// Offers a finished hypothesis with the given summed log-probability.
    ///
    /// The score is `sum_log_probabilities / length^length_penalty`, where the length is the size
    /// of dimension 0 of `hypothesis`. The hypothesis is kept if there is still room or if its
    /// score beats the current worst one. Going over `num_beams` entries evicts the worst
    /// (first-found on ties), and the worst score is then recomputed.
    fn add(&mut self, hypothesis: Tensor, sum_log_probabilities: f64) {
        let hypothesis_length = *hypothesis.size().first().unwrap() as f64;
        let score = sum_log_probabilities / hypothesis_length.powf(self.length_penalty);
        let has_free_slot = self.len() < self.num_beams;
        // Written as a negation (not `<=`) so that NaN scores are rejected exactly as before.
        if !(has_free_slot || score > self.worst_score) {
            return;
        }
        self.beams.push((score, hypothesis));
        if self.len() > self.num_beams {
            let worst_position = self
                .beams
                .iter()
                .enumerate()
                .min_by_key(|(_, (beam_score, _))| OrderedFloat(*beam_score))
                .map(|(position, _)| position)
                .unwrap();
            let _ = self.beams.remove(worst_position);
        }
        self.worst_score = self
            .beams
            .iter()
            .map(|(beam_score, _)| OrderedFloat(*beam_score))
            .min()
            .unwrap()
            .0;
    }

    /// Whether this batch entry can stop generating.
    ///
    /// The result is false while fewer than `num_beams` hypotheses are stored. It is true right
    /// away with `early_stopping`. Otherwise it is true when the worst stored score is at least
    /// the length-normalised `best_sum_log_probabilities` at `current_length`.
    fn is_done(&self, best_sum_log_probabilities: f64, current_length: i64) -> bool {
        if self.len() < self.num_beams {
            return false;
        }
        if self.early_stopping {
            return true;
        }
        let best_reachable_score =
            best_sum_log_probabilities / (current_length as f64).powf(self.length_penalty);
        self.worst_score >= best_reachable_score
    }
}
/// Model with a language-modelling head that the generation routines in this file drive.
///
/// The beam-search loop calls `forward_t` once per decoding step. It reads the first element
/// of the returned tuple as the output logits and the third element as the updated cache.
pub trait LMHeadModel {
/// Runs a forward pass of the model.
///
/// # Arguments
/// * `input_ids` - optional input token ids
/// * `layer_past` - cached state from a previous step (beam search starts from `Cache::None`)
/// * `attention_mask` - optional attention mask
/// * `token_type_ids` - optional token type ids
/// * `position_ids` - optional position ids
/// * `input_embeds` - optional input embeddings
/// * `encoder_outputs` - encoder output, if any (used by encoder-decoder models)
/// * `decoder_input_ids` - optional decoder input ids
/// * `train` - training flag (the beam search passes `false`)
///
/// # Returns
/// On success, a tuple whose first element the beam search treats as the logits. It selects the
/// last position along dimension 1, which presumably means a `(batch, sequence, vocab)` layout
/// (TODO confirm). The third element is the updated `Cache`.
/// NOTE(review): the second, fourth and fifth elements are not used here. They are presumably
/// the encoder hidden state, all hidden states and attentions; confirm against the
/// implementations. On failure, a static error message is returned.
fn forward_t(
&self,
input_ids: &Option<Tensor>,
layer_past: Cache,
attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
>;
}