// Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use crate::m2m_100::decoder::M2M100Decoder;
use crate::m2m_100::encoder::M2M100Encoder;
use crate::m2m_100::LayerState;
use crate::mbart::{MBartConfig, MBartModelOutput};
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
};
use crate::pipelines::translation::Language;
use crate::resources::{RemoteResource, Resource};
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::{M2M100Tokenizer, TruncationStrategy};
use rust_tokenizers::vocab::{M2M100Vocab, Vocab};
use std::borrow::Borrow;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
/// # M2M100 Pretrained model weight files
pub struct M2M100ModelResources;
/// # M2M100 Pretrained model config files
pub struct M2M100ConfigResources;
/// # M2M100 Pretrained model vocab files
pub struct M2M100VocabResources;
/// # M2M100 Pretrained model merges files
pub struct M2M100MergesResources;
/// # M2M100 source languages pre-sets
pub struct M2M100SourceLanguages;
/// # M2M100 target languages pre-sets
pub type M2M100TargetLanguages = M2M100SourceLanguages;
impl M2M100ModelResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const M2M100_418M: (&'static str, &'static str) = (
"m2m100-418m/model",
"https://huggingface.co/facebook/m2m100_418M/resolve/main/rust_model.ot",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const M2M100_1_2B: (&'static str, &'static str) = (
"m2m100-1_2b/model",
"https://huggingface.co/facebook/m2m100_1.2B/resolve/main/rust_model.ot",
);
}
impl M2M100ConfigResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const M2M100_418M: (&'static str, &'static str) = (
"m2m100-418m/config",
"https://huggingface.co/facebook/m2m100_418M/resolve/main/config.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const M2M100_1_2B: (&'static str, &'static str) = (
"m2m100-1_2b/config",
"https://huggingface.co/facebook/m2m100_1.2B/resolve/main/config.json",
);
}
impl M2M100VocabResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const M2M100_418M: (&'static str, &'static str) = (
"m2m100-418m/vocab",
"https://huggingface.co/facebook/m2m100_418M/resolve/main/vocab.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const M2M100_1_2B: (&'static str, &'static str) = (
"m2m100-1_2b/vocab",
"https://huggingface.co/facebook/m2m100_1.2B/resolve/main/vocab.json",
);
}
impl M2M100MergesResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const M2M100_418M: (&'static str, &'static str) = (
"m2m100-418m/merges",
"https://huggingface.co/facebook/m2m100_418M/resolve/main/sentencepiece.bpe.model",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at <https://github.com/pytorch/fairseq>. Modified with conversion to C-array format.
pub const M2M100_1_2B: (&'static str, &'static str) = (
"m2m100-1_2b/merges",
"https://huggingface.co/facebook/m2m100_1.2B/resolve/main/sentencepiece.bpe.model",
);
}
#[rustfmt::skip]
impl M2M100SourceLanguages {
pub const M2M100_418M: [Language; 100] = [Language::Afrikaans, Language::Danish, Language::Dutch, Language::German, Language::English, Language::Icelandic, Language::Luxembourgish, Language::Norwegian, Language::Swedish, Language::WesternFrisian, Language::Yiddish, Language::Asturian, Language::Catalan, Language::French, Language::Galician, Language::Italian, Language::Occitan, Language::Portuguese, Language::Romanian, Language::Spanish, Language::Belarusian, Language::Bosnian, Language::Bulgarian, Language::Croatian, Language::Czech, Language::Macedonian, Language::Polish, Language::Russian, Language::Serbian, Language::Slovak, Language::Slovenian, Language::Ukrainian, Language::Estonian, Language::Finnish, Language::Hungarian, Language::Latvian, Language::Lithuanian, Language::Albanian, Language::Armenian, Language::Georgian, Language::Greek, Language::Breton, Language::Irish, Language::ScottishGaelic, Language::Welsh, Language::Azerbaijani, Language::Bashkir, Language::Kazakh, Language::Turkish, Language::Uzbek, Language::Japanese, Language::Korean, Language::Vietnamese, Language::ChineseMandarin, Language::Bengali, Language::Gujarati, Language::Hindi, Language::Kannada, Language::Marathi, Language::Nepali, Language::Oriya, Language::Panjabi, Language::Sindhi, Language::Sinhala, Language::Urdu, Language::Tamil, Language::Cebuano, Language::Iloko, Language::Indonesian, Language::Javanese, Language::Malagasy, Language::Malay, Language::Malayalam, Language::Sundanese, Language::Tagalog, Language::Burmese, Language::CentralKhmer, Language::Lao, Language::Thai, Language::Mongolian, Language::Arabic, Language::Hebrew, Language::Pashto, Language::Farsi, Language::Amharic, Language::Fulah, Language::Hausa, Language::Igbo, Language::Lingala, Language::Luganda, Language::NorthernSotho, Language::Somali, Language::Swahili, Language::Swati, Language::Tswana, Language::Wolof, Language::Xhosa, Language::Yoruba, Language::Zulu, Language::HaitianCreole];
pub const M2M100_1_2B: [Language; 100] = M2M100SourceLanguages::M2M100_418M;
}
pub type M2M100Config = MBartConfig;
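/// Shifts input ids one position to the right: the first column of the output is filled with
/// `decoder_start_token_id`, the remaining columns receive `input_ids[..., :-1]`, and any
/// label value of -100 is replaced by `pad_token_id`. This builds the decoder inputs from the
/// target token ids when no `decoder_input_ids` are provided to the forward pass.
///
/// Illustrative values (not taken from the source), assuming `pad_token_id = 1` and
/// `decoder_start_token_id = 2`:
///
/// ```text
/// input_ids:         [[5, 6, 7]]
/// shifted input ids: [[2, 5, 6]]
/// ```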
fn _shift_tokens_right(
input_ids: &Tensor,
pad_token_id: i64,
decoder_start_token_id: i64,
) -> Tensor {
let shifted_input_ids = Tensor::zeros(
input_ids.size().as_slice(),
(Kind::Int64, input_ids.device()),
);
let _ = shifted_input_ids.select(1, 0).fill_(decoder_start_token_id);
let _ = shifted_input_ids
.slice(1, 1, *shifted_input_ids.size().last().unwrap(), 1)
.copy_(&input_ids.slice(1, 0, *input_ids.size().last().unwrap() - 1, 1));
shifted_input_ids.masked_fill(&shifted_input_ids.eq(-100), pad_token_id)
}
/// # M2M100 Base model
/// Base architecture for the M2M100 model. Usually complemented with a task-specific head, such as a language model head.
/// It is made of the following blocks:
/// - `encoder`: `M2M100Encoder` (transformer) made of a vector of encoding layers
/// - `decoder`: `M2M100Decoder` (transformer) made of a vector of decoding layers with self attention and encoder cross-attention.
///   Caching is implemented for the decoder to avoid recalculating static states (encoder key/values and previously calculated decoder key/values)
/// - `embeddings`: token embeddings shared between the encoder and decoder
/// - `pad_token_id`: padding token id
/// - `decoder_start_token_id`: token id used as the first decoder input when shifting tokens to the right
pub struct M2M100Model {
pub(crate) encoder: M2M100Encoder,
decoder: M2M100Decoder,
pub(crate) embeddings: nn::Embedding,
pad_token_id: i64,
decoder_start_token_id: i64,
}
impl M2M100Model {
/// Build a new `M2M100Model`
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the M2M100 model
/// * `config` - `M2M100Config` object defining the model architecture
///
/// # Example
///
/// ```no_run
/// use rust_bert::m2m_100::{M2M100Config, M2M100Model};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = M2M100Config::from_file(config_path);
/// let m2m100: M2M100Model = M2M100Model::new(&p.root() / "m2m100", &config);
/// ```
pub fn new<'p, P>(p: P, config: &M2M100Config) -> M2M100Model
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let pad_token_id = config.pad_token_id.unwrap_or(1);
let decoder_start_token_id = config.decoder_start_token_id.unwrap_or(2);
let embedding_config = EmbeddingConfig {
padding_idx: pad_token_id,
..Default::default()
};
let embeddings: nn::Embedding = embedding(
p / "shared",
config.vocab_size,
config.d_model,
embedding_config,
);
let encoder = M2M100Encoder::new(p / "encoder", config);
let decoder = M2M100Decoder::new(p / "decoder", config);
M2M100Model {
encoder,
decoder,
embeddings,
pad_token_id,
decoder_start_token_id,
}
}
/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `encoder_output` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*) corresponding to the last encoder hidden state. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
/// * `layer_states` - Optional vector of length *n_layer* containing tuples of optional `LayerState`s with the cached keys and values of the decoder self attention and encoder cross attention for each layer. Avoids recomputing attention at past positions during generation.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `M2M100ModelOutput` containing:
/// - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last decoder hidden state
/// - `encoder_hidden_state` - `Option<Tensor>` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state if it was not provided, otherwise None
/// - `cache` - `Option<Vec<(Option<LayerState>, Option<LayerState>)>>` of length *n_layer* containing the past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
/// - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
///
/// # Example
///
/// ```no_run
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::m2m_100::{M2M100Config, M2M100Model};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = M2M100Config::from_file(config_path);
/// # let m2m100_model: M2M100Model = M2M100Model::new(&vs.root(), &config);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::randint(1000, &[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::randint(1000, &[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask =
///     Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask =
///     Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
///
/// let model_output = no_grad(|| {
/// m2m100_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// Some(&target_tensor),
/// None,
/// Some(&decoder_attention_mask),
/// None,
/// false,
/// )
/// });
/// ```
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
decoder_input_ids: Option<&Tensor>,
encoder_output: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> M2M100ModelOutput {
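// When no decoder input ids are provided, derive them from the input ids by shifting
// them one position to the right (teacher forcing from the target labels).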
let calc_decoder_input_ids = if decoder_input_ids.is_none() {
Some(_shift_tokens_right(
input_ids.unwrap(),
self.pad_token_id,
self.decoder_start_token_id,
))
} else {
None
};
let decoder_input_ids =
decoder_input_ids.unwrap_or_else(|| calc_decoder_input_ids.as_ref().unwrap());
let calc_encoder_output = if encoder_output.is_none() {
Some(self.encoder.forward_t(
input_ids.unwrap(),
attention_mask,
&self.embeddings,
train,
))
} else {
None
};
let (calc_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
if let Some(calc_encoder_output) = calc_encoder_output {
(
Some(calc_encoder_output.hidden_state),
calc_encoder_output.all_hidden_states,
calc_encoder_output.all_attentions,
)
} else {
(None, None, None)
};
let encoder_output = encoder_output.unwrap_or_else(|| calc_hidden_states.as_ref().unwrap());
let decoder_output = self.decoder.forward_t(
decoder_input_ids,
encoder_output,
attention_mask,
decoder_attention_mask,
&self.embeddings,
layer_states,
train,
);
M2M100ModelOutput {
decoder_output: decoder_output.hidden_state,
encoder_hidden_state: calc_hidden_states,
cache: decoder_output.next_decoder_cache,
all_decoder_hidden_states: decoder_output.all_hidden_states,
all_decoder_attentions: decoder_output.all_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
}
}
}
/// Container holding an M2M100 model output
pub type M2M100ModelOutput = MBartModelOutput;
/// # M2M100 Model for conditional generation
/// M2M100 model with a vocabulary decoding head
/// It is made of the following blocks:
/// - `base_model`: `M2M100Model` Base M2M100 model
/// - `linear`: Linear layer without bias tied to the weights of the token id embeddings
pub struct M2M100ForConditionalGeneration {
base_model: M2M100Model,
}
impl M2M100ForConditionalGeneration {
/// Build a new `M2M100ForConditionalGeneration`
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the M2M100 model
/// * `config` - `M2M100Config` object defining the model architecture
///
/// # Example
///
/// ```no_run
/// use rust_bert::m2m_100::{M2M100Config, M2M100ForConditionalGeneration};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = M2M100Config::from_file(config_path);
/// let m2m100: M2M100ForConditionalGeneration =
/// M2M100ForConditionalGeneration::new(&p.root(), &config);
/// ```
pub fn new<'p, P>(p: P, config: &M2M100Config) -> M2M100ForConditionalGeneration
where
P: Borrow<nn::Path<'p>>,
{
let base_model = M2M100Model::new(p.borrow() / "model", config);
M2M100ForConditionalGeneration { base_model }
}
/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). Must be provided when not running in generation mode
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `encoder_output` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*) corresponding to the last encoder hidden state. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
/// * `old_layer_states` - Optional vector of length *n_layer* containing tuples of optional `LayerState`s with the cached keys and values of the decoder self attention and encoder cross attention for each layer. Avoids recomputing attention at past positions during generation.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `M2M100ModelOutput` containing:
/// - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocab_size*) representing the logits for each vocabulary item and position
/// - `encoder_hidden_state` - `Option<Tensor>` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state if it was not provided, otherwise None
/// - `cache` - `Option<Vec<(Option<LayerState>, Option<LayerState>)>>` of length *n_layer* containing the past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
/// - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
///
/// # Example
///
/// ```no_run
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// # use rust_bert::m2m_100::{M2M100Config, M2M100ForConditionalGeneration};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = M2M100Config::from_file(config_path);
/// # let m2m100_model: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&vs.root(), &config);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::randint(1000, &[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::randint(1000, &[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
///
/// let model_output = no_grad(|| {
/// m2m100_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// });
/// ```
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_output: Option<&Tensor>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> M2M100ModelOutput {
let base_model_output = self.base_model.forward_t(
input_ids,
attention_mask,
decoder_input_ids,
encoder_output,
decoder_attention_mask,
old_layer_states,
train,
);
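// Project the decoder hidden states onto the vocabulary by reusing the token embedding
// matrix as the projection weights (weight tying), with no bias.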
let lm_logits = base_model_output
.decoder_output
.linear::<Tensor>(&self.base_model.embeddings.ws, None);
M2M100ModelOutput {
decoder_output: lm_logits,
..base_model_output
}
}
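/// Runs only the encoder part of the model and returns its last hidden state, which can be
/// passed back to `forward_t` as `encoder_output` to avoid recomputing the encoder
/// activations at every generation step.
///
/// A minimal usage sketch (paths and tensor values are placeholders):
///
/// ```no_run
/// # use tch::{nn, Device, Tensor, Kind};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use rust_bert::m2m_100::{M2M100Config, M2M100ForConditionalGeneration};
/// # let config = M2M100Config::from_file(Path::new("path/to/config.json"));
/// # let vs = nn::VarStore::new(Device::Cpu);
/// # let model = M2M100ForConditionalGeneration::new(&vs.root(), &config);
/// let input_ids = Tensor::randint(1000, &[4, 32], (Kind::Int64, Device::Cpu));
/// let encoder_hidden_state = model.encode(&input_ids, None);
/// ```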
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
self.base_model
.encoder
.forward_t(
input_ids,
attention_mask,
&self.base_model.embeddings,
false,
)
.hidden_state
}
}
impl LMHeadModel for M2M100ForConditionalGeneration {
/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). Must be provided when not running in generation mode, as pre-computed embeddings are not supported for M2M100
/// * `cache` - `Cache` object wrapping an optional vector of tuples of optional `LayerState`s containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
/// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked value 1. If None, no positions are masked.
/// * `input_embeds` - Unused for M2M100
/// * `token_type_ids` - Unused for M2M100
/// * `position_ids` - Unused for M2M100
/// * `encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*). When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `LMModelOutput` containing:
/// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// - `cache` - `Cache::BARTCache` wrapping an `Option<Vec<(Option<LayerState>, Option<LayerState>)>>` of length *n_layer* containing the past keys and values for
/// both the self attention and the encoder cross attention of each layer of the decoder.
///
/// # Example
///
/// ```no_run
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::pipelines::generation_utils::LMHeadModel;
/// use rust_bert::m2m_100::{M2M100ForConditionalGeneration, M2M100Config};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = M2M100Config::from_file(config_path);
/// # let m2m100_model: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&vs.root(), &config);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::randint(1000, &[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::randint(1000, &[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
///
/// let model_output = no_grad(|| {
/// m2m100_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// });
/// ```
fn forward_t(
&self,
input_ids: Option<&Tensor>,
cache: Cache,
attention_mask: Option<&Tensor>,
_token_type_ids: Option<&Tensor>,
_position_ids: Option<&Tensor>,
_input_embeds: Option<&Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: Option<&Tensor>,
train: bool,
) -> Result<LMModelOutput, RustBertError> {
let base_model_output = match cache {
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
input_ids,
attention_mask,
decoder_input_ids,
encoder_outputs,
None,
cached_layer_states,
train,
),
Cache::None => self.base_model.forward_t(
input_ids,
attention_mask,
decoder_input_ids,
encoder_outputs,
None,
None,
train,
),
_ => {
return Err(RustBertError::ValueError(
"Cache not compatible with M2M100 Model".into(),
));
}
};
let lm_logits = base_model_output
.decoder_output
.linear::<Tensor>(&self.base_model.embeddings.ws, None);
Ok(LMModelOutput {
lm_logits,
cache: Cache::BARTCache(base_model_output.cache),
})
}
}
/// # Language generation model based on the M2M100 architecture
pub struct M2M100Generator {
model: M2M100ForConditionalGeneration,
tokenizer: TokenizerOption,
var_store: nn::VarStore,
generate_config: GenerateConfig,
bos_token_id: Option<i64>,
eos_token_ids: Option<Vec<i64>>,
pad_token_id: Option<i64>,
is_encoder_decoder: bool,
vocab_size: i64,
decoder_start_id: Option<i64>,
max_position_embeddings: i64,
}
impl M2M100Generator {
/// Build a new `M2M100Generator`
///
/// # Arguments
///
/// * `vocab_path` - Path to the model vocabulary, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
/// * `merges_path` - Path to the bpe merges, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
/// * `config_path` - Path to the model configuration, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
/// * `weights_path` - Path to the model weight files. These need to be converted from the `.bin` to `.ot` format using the utility script provided.
/// * `device` - Device to run the model on, e.g. `Device::Cpu` or `Device::Cuda(0)`
///
/// # Example
///
/// ```no_run
/// # use std::path::PathBuf;
/// # use tch::Device;
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::m2m_100::M2M100Generator;
/// use rust_bert::pipelines::generation_utils::GenerateConfig;
/// # let mut home: PathBuf = dirs::home_dir().unwrap();
/// # home.push("rustbert");
/// # home.push("openai-gpt");
/// # let config_path = &home.as_path().join("config.json");
/// # let vocab_path = &home.as_path().join("vocab.txt");
/// # let merges_path = &home.as_path().join("merges.txt");
/// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,
/// num_return_sequences: 3,
/// ..Default::default()
/// };
/// let m2m100_generator = M2M100Generator::new(generate_config)?;
/// # Ok(())
/// # }
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<M2M100Generator, RustBertError> {
// The following allows keeping the same `GenerateConfig` default for GPT, GPT2 and BART models
let model_resource = if generate_config.model_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
M2M100ModelResources::M2M100_418M,
))
} else {
generate_config.model_resource.clone()
};
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_418M,
))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
M2M100VocabResources::M2M100_418M,
))
} else {
generate_config.vocab_resource.clone()
};
let merges_resource = if generate_config.merges_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
M2M100MergesResources::M2M100_418M,
))
} else {
generate_config.merges_resource.clone()
};
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let tokenizer = TokenizerOption::from_file(
ModelType::M2M100,
vocab_path.to_str().unwrap(),
Some(merges_path.to_str().unwrap()),
false,
None,
None,
)?;
let config = M2M100Config::from_file(config_path);
let model = M2M100ForConditionalGeneration::new(&var_store.root(), &config);
var_store.load(weights_path)?;
let bos_token_id = Some(0);
let eos_token_ids = Some(match config.eos_token_id {
Some(value) => vec![value],
None => vec![2],
});
let pad_token_id = Some(config.pad_token_id.unwrap_or(1));
let vocab_size = config.vocab_size;
let is_encoder_decoder = true;
let decoder_start_id = Some(2);
let max_position_embeddings = config.max_position_embeddings;
Ok(M2M100Generator {
model,
tokenizer,
var_store,
generate_config,
bos_token_id,
eos_token_ids,
pad_token_id,
is_encoder_decoder,
vocab_size,
decoder_start_id,
max_position_embeddings,
})
}
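/// Forces the generation of the provided `token_ids` by setting the score of every other
/// vocabulary entry to negative infinity, leaving only the allowed tokens selectable.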
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
let _ = scores.index_fill_(1, &impossible_tokens, f64::NEG_INFINITY);
}
}
impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M100Tokenizer>
for M2M100Generator
{
fn get_model(&self) -> &M2M100ForConditionalGeneration {
&self.model
}
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
}
fn get_bos_id(&self) -> &Option<i64> {
&self.bos_token_id
}
fn get_eos_ids(&self) -> &Option<Vec<i64>> {
&self.eos_token_ids
}
fn get_pad_id(&self) -> &Option<i64> {
&self.pad_token_id
}
fn is_encoder_decoder(&self) -> bool {
self.is_encoder_decoder
}
fn get_vocab_size(&self) -> i64 {
self.vocab_size
}
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
}
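// At the first decoding step, force the target language token (token id 250004 is used as a
// default when no forced BOS token id is provided); at the last step before `max_length`,
// force EOS so that every hypothesis terminates.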
fn prepare_scores_for_generation(
&self,
scores: &mut Tensor,
current_length: i64,
max_length: i64,
forced_bos_token_id: Option<i64>,
) {
if current_length == 1 {
self.force_token_id_generation(scores, &[forced_bos_token_id.unwrap_or(250004)]);
} else if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
}
}
fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
Some(self.get_model().encode(input_ids, attention_mask))
}
fn prepare_inputs_for_generation<'a>(
&self,
input_ids: Tensor,
encoder_outputs: Option<&'a Tensor>,
past: Cache,
attention_mask: Tensor,
) -> PreparedInput<'a> {
match past {
Cache::BARTCache(past) => PreparedInput {
prepared_input: None,
prepared_attention_mask: Some(attention_mask),
prepared_encoder_output: encoder_outputs,
prepared_decoder_input: Some(input_ids.narrow(1, -1, 1)),
prepared_position_ids: None,
prepared_past: Cache::BARTCache(past),
},
Cache::None => PreparedInput {
prepared_input: None,
prepared_attention_mask: Some(attention_mask),
prepared_encoder_output: encoder_outputs,
prepared_decoder_input: Some(input_ids),
prepared_position_ids: None,
prepared_past: Cache::BARTCache(None),
},
_ => panic!("Cache type incompatible with M2M100"),
}
}
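// Tokenizes a batch of prompts, truncating each to `max_len` and padding all sequences to
// the longest one in the batch (falling back to the unknown token id when no pad token is
// provided).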
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len as usize,
&TruncationStrategy::LongestFirst,
0,
);
let token_ids = tokens
.into_iter()
.map(|tokenized_input| tokenized_input.token_ids)
.collect::<Vec<Vec<i64>>>();
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
let pad_token = match pad_token_id {
Some(value) => value,
None => self
._get_tokenizer()
.convert_tokens_to_ids(&[M2M100Vocab::unknown_value()])[0],
};
let token_ids = token_ids
.into_iter()
.map(|mut input| {
let temp = vec![pad_token; max_len - input.len()];
input.extend(temp);
input
})
.map(|tokens| Tensor::of_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)
}
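// Reorders the cached decoder states (and optionally the encoder output) along the batch
// dimension so that cached values follow their hypotheses when beam search reorders beams.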
fn reorder_cache(
&self,
past: &mut Cache,
encoder_outputs: Option<Tensor>,
beam_indices: &Tensor,
) -> Option<Tensor> {
let encoder_outputs = encoder_outputs.map(|value| value.index_select(0, beam_indices));
match past {
Cache::BARTCache(old_cache_option) => match old_cache_option {
Some(old_cache) => {
for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
if self_layer_state.is_some() {
self_layer_state
.as_mut()
.unwrap()
.reorder_cache(beam_indices)
};
if encoder_layer_state.is_some() {
encoder_layer_state
.as_mut()
.unwrap()
.reorder_cache(beam_indices)
};
}
}
None => {}
},
Cache::None => {}
_ => {
panic!("Invalid cache for M2M100 model");
}
};
encoder_outputs
}
}
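/// The public entry point for translation is the `generate` method supplied by the
/// `LanguageGenerator` trait; the forced BOS token applied in `prepare_scores_for_generation`
/// steers generation towards the requested target language.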
impl LanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M100Tokenizer>
for M2M100Generator
{
}
#[cfg(test)]
mod test {
use tch::Device;
use crate::{
resources::{RemoteResource, Resource},
Config,
};
use super::*;
#[test]
#[ignore] // compilation is enough, no need to run
fn m2m100_model_send() {
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
M2M100ConfigResources::M2M100_418M,
));
let config_path = config_resource.get_local_path().expect("Could not download config");
// Set-up M2M100 model
let device = Device::cuda_if_available();
let vs = tch::nn::VarStore::new(device);
let config = M2M100Config::from_file(config_path);
let _: Box<dyn Send> = Box::new(M2M100Model::new(&vs.root(), &config));
}
}