use crate::common::activations::{Activation, TensorFunction};
use crate::common::dropout::Dropout;
use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
use crate::mobilebert::embeddings::MobileBertEmbeddings;
use crate::mobilebert::encoder::{MobileBertEncoder, MobileBertPooler};
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
use tch::nn::{Init, LayerNormConfig, Module};
use tch::{nn, Kind, Tensor};
/// # MobileBERT pretrained model weight resources
pub struct MobileBertModelResources;
/// # MobileBERT pretrained model configuration resources
pub struct MobileBertConfigResources;
/// # MobileBERT pretrained model vocabulary resources
pub struct MobileBertVocabResources;
impl MobileBertModelResources {
    /// (cache subdirectory, download URL) pair for the `google/mobilebert-uncased` weights.
    pub const MOBILEBERT_UNCASED: (&'static str, &'static str) = (
        "mobilebert-uncased/model",
        "https://huggingface.co/google/mobilebert-uncased/resolve/main/rust_model.ot",
    );
    /// (cache subdirectory, download URL) pair for the `mrm8488/mobilebert-finetuned-pos`
    /// weights (part-of-speech tagging fine-tune).
    pub const MOBILEBERT_ENGLISH_POS: (&'static str, &'static str) = (
        "mobilebert-finetuned-pos/model",
        "https://huggingface.co/mrm8488/mobilebert-finetuned-pos/resolve/main/rust_model.ot",
    );
}
impl MobileBertConfigResources {
    /// (cache subdirectory, download URL) pair for the `google/mobilebert-uncased` configuration.
    pub const MOBILEBERT_UNCASED: (&'static str, &'static str) = (
        "mobilebert-uncased/config",
        "https://huggingface.co/google/mobilebert-uncased/resolve/main/config.json",
    );
    /// (cache subdirectory, download URL) pair for the `mrm8488/mobilebert-finetuned-pos` configuration.
    pub const MOBILEBERT_ENGLISH_POS: (&'static str, &'static str) = (
        "mobilebert-finetuned-pos/config",
        "https://huggingface.co/mrm8488/mobilebert-finetuned-pos/resolve/main/config.json",
    );
}
impl MobileBertVocabResources {
    /// (cache subdirectory, download URL) pair for the `google/mobilebert-uncased` vocabulary.
    pub const MOBILEBERT_UNCASED: (&'static str, &'static str) = (
        "mobilebert-uncased/vocab",
        "https://huggingface.co/google/mobilebert-uncased/resolve/main/vocab.txt",
    );
    /// (cache subdirectory, download URL) pair for the `mrm8488/mobilebert-finetuned-pos` vocabulary.
    pub const MOBILEBERT_ENGLISH_POS: (&'static str, &'static str) = (
        "mobilebert-finetuned-pos/vocab",
        "https://huggingface.co/mrm8488/mobilebert-finetuned-pos/resolve/main/vocab.txt",
    );
}
/// Normalization flavor used by MobileBERT layers.
/// Variant names are lowercase so serde (de)serialization matches the values
/// found in the model's JSON configuration files.
#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
pub enum NormalizationType {
    /// Standard layer normalization
    layer_norm,
    /// Element-wise affine transform without mean/variance statistics (see `NoNorm`)
    no_norm,
}
/// Lightweight replacement for layer normalization used by MobileBERT:
/// a learned element-wise scale (`weight`) and shift (`bias`), with no
/// mean/variance normalization.
#[derive(Debug)]
pub struct NoNorm {
    weight: Tensor,
    bias: Tensor,
}
impl NoNorm {
    /// Creates a new `NoNorm` layer.
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the layer parameters
    /// * `hidden_size` - Size of the (last) dimension the transform applies to
    pub fn new<'p, P>(p: P, hidden_size: i64) -> NoNorm
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        NoNorm {
            // Identity initialization: scale of ones, shift of zeros.
            weight: path.var("weight", &[hidden_size], Init::Const(1.0)),
            bias: path.var("bias", &[hidden_size], Init::Const(0.0)),
        }
    }
}
impl Module for NoNorm {
    /// Applies the element-wise affine transform `xs * weight + bias`.
    fn forward(&self, xs: &Tensor) -> Tensor {
        let scaled = xs * &self.weight;
        scaled + &self.bias
    }
}
/// Wrapper holding either of the two normalization implementations supported
/// by MobileBERT behind a single type.
pub enum NormalizationLayer {
    LayerNorm(nn::LayerNorm),
    NoNorm(NoNorm),
}
impl NormalizationLayer {
    /// Creates a normalization layer of the requested flavor.
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the layer parameters
    /// * `normalization_type` - which normalization implementation to build
    /// * `hidden_size` - size of the normalized dimension
    /// * `eps` - epsilon for layer normalization (defaults to 1e-12; ignored by `NoNorm`)
    pub fn new<'p, P>(
        p: P,
        normalization_type: NormalizationType,
        hidden_size: i64,
        eps: Option<f64>,
    ) -> NormalizationLayer
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        match normalization_type {
            NormalizationType::no_norm => {
                NormalizationLayer::NoNorm(NoNorm::new(path, hidden_size))
            }
            NormalizationType::layer_norm => {
                let config = LayerNormConfig {
                    eps: eps.unwrap_or(1e-12),
                    ..Default::default()
                };
                NormalizationLayer::LayerNorm(nn::layer_norm(path, vec![hidden_size], config))
            }
        }
    }
    /// Applies the wrapped normalization to `input`.
    pub fn forward(&self, input: &Tensor) -> Tensor {
        match self {
            NormalizationLayer::LayerNorm(norm) => input.apply(norm),
            NormalizationLayer::NoNorm(norm) => input.apply(norm),
        }
    }
}
/// # MobileBERT model configuration
/// Deserialized from the model's JSON configuration file; field names match the
/// keys of the upstream (HuggingFace) MobileBERT configuration.
#[derive(Debug, Serialize, Deserialize)]
pub struct MobileBertConfig {
    /// Activation function used in the model (applied e.g. in the prediction head transform)
    pub hidden_act: Activation,
    /// Dropout probability for attention probabilities
    pub attention_probs_dropout_prob: f64,
    /// Dropout probability for hidden states (used by the classification heads in this file)
    pub hidden_dropout_prob: f64,
    /// Hidden layer size
    pub hidden_size: i64,
    /// Parameter initializer range — NOTE(review): usage not visible in this file; confirm in embeddings/encoder modules
    pub initializer_range: f64,
    /// Feed-forward intermediate layer size
    pub intermediate_size: i64,
    /// Maximum sequence length supported by the position embeddings
    pub max_position_embeddings: i64,
    /// Number of attention heads per layer
    pub num_attention_heads: i64,
    /// Number of encoder layers
    pub num_hidden_layers: i64,
    /// Number of token type (segment) embeddings
    pub type_vocab_size: i64,
    /// Vocabulary size
    pub vocab_size: i64,
    /// Word embedding dimension (smaller than `hidden_size`; the LM head pads
    /// the decoder weight with `hidden_size - embedding_size` extra rows)
    pub embedding_size: i64,
    /// Layer normalization epsilon (1e-12 used when absent)
    pub layer_norm_eps: Option<f64>,
    /// Padding token index
    pub pad_token_idx: Option<i64>,
    /// Whether trigram input projection is used — consumed by the embeddings module
    pub trigram_input: Option<bool>,
    /// Whether bottleneck layers are used — consumed by the encoder module
    pub use_bottleneck: Option<bool>,
    /// Whether attention uses the bottleneck representation — consumed by the encoder module
    pub use_bottleneck_attention: Option<bool>,
    /// Intra-bottleneck hidden size — consumed by the encoder module
    pub intra_bottleneck_size: Option<i64>,
    /// Whether keys and queries share the bottleneck projection — consumed by the encoder module
    pub key_query_shared_bottleneck: Option<bool>,
    /// Number of feed-forward sub-networks per layer — consumed by the encoder module
    pub num_feedforward_networks: Option<i64>,
    /// Normalization flavor (`layer_norm` or `no_norm`)
    pub normalization_type: Option<NormalizationType>,
    /// Whether attention weights should be returned by the encoder
    pub output_attentions: Option<bool>,
    /// Whether intermediate hidden states should be returned by the encoder
    pub output_hidden_states: Option<bool>,
    /// Whether the pooler applies an activation — consumed by the pooler module
    pub classifier_activation: Option<bool>,
    /// Decoder-mode flag — NOTE(review): usage not visible in this file
    pub is_decoder: Option<bool>,
    /// Label index to label string mapping (defines `num_labels` for classification heads)
    pub id2label: Option<HashMap<i64, String>>,
    /// Label string to label index mapping
    pub label2id: Option<HashMap<String, i64>>,
}
/// Marker implementation enabling the shared `Config` loading helpers for MobileBERT.
impl Config for MobileBertConfig {}
impl Default for MobileBertConfig {
    /// Default configuration values — NOTE(review): these appear to mirror the
    /// `google/mobilebert-uncased` configuration; confirm against the published
    /// config.json if exactness matters.
    fn default() -> Self {
        MobileBertConfig {
            hidden_act: Activation::relu,
            attention_probs_dropout_prob: 0.1,
            hidden_dropout_prob: 0.0,
            hidden_size: 512,
            initializer_range: 0.02,
            intermediate_size: 512,
            max_position_embeddings: 512,
            num_attention_heads: 4,
            num_hidden_layers: 24,
            type_vocab_size: 2,
            vocab_size: 30522,
            embedding_size: 128,
            layer_norm_eps: Some(1e-12),
            pad_token_idx: Some(0),
            trigram_input: Some(true),
            use_bottleneck: Some(true),
            use_bottleneck_attention: Some(false),
            intra_bottleneck_size: Some(128),
            key_query_shared_bottleneck: Some(true),
            num_feedforward_networks: Some(4),
            normalization_type: Some(NormalizationType::no_norm),
            // Optional/task-specific fields left unset by default.
            output_attentions: None,
            output_hidden_states: None,
            classifier_activation: None,
            is_decoder: None,
            id2label: None,
            label2id: None,
        }
    }
}
/// Transform applied to hidden states before masked-LM decoding:
/// dense projection -> activation -> layer normalization.
pub struct MobileBertPredictionHeadTransform {
    dense: nn::Linear,
    activation_function: TensorFunction,
    layer_norm: NormalizationLayer,
}
impl MobileBertPredictionHeadTransform {
    /// Creates the prediction head transform (dense -> activation -> layer norm).
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the transform parameters
    /// * `config` - MobileBERT configuration (provides hidden size, activation and norm epsilon)
    pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertPredictionHeadTransform
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        let dense = nn::linear(
            path / "dense",
            config.hidden_size,
            config.hidden_size,
            Default::default(),
        );
        MobileBertPredictionHeadTransform {
            dense,
            activation_function: config.hidden_act.get_function(),
            // The transform always uses true layer normalization, regardless of
            // the `normalization_type` configured for the encoder.
            layer_norm: NormalizationLayer::new(
                path / "LayerNorm",
                NormalizationType::layer_norm,
                config.hidden_size,
                config.layer_norm_eps,
            ),
        }
    }
    /// Applies dense projection, activation, then layer normalization.
    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
        let projected = hidden_states.apply(&self.dense);
        let activated = self.activation_function.get_fn()(&projected);
        self.layer_norm.forward(&activated)
    }
}
/// Masked-LM prediction head. Vocabulary scores are computed against the
/// concatenation of the (transposed) word embedding matrix and the learned
/// `dense_weight`, which covers the remaining `hidden_size - embedding_size` rows.
pub struct MobileBertLMPredictionHead {
    transform: MobileBertPredictionHeadTransform,
    dense_weight: Tensor,
    bias: Tensor,
}
impl MobileBertLMPredictionHead {
    /// Creates the masked-LM prediction head.
    ///
    /// The decoder weight is split in two: the word embedding matrix (tied,
    /// supplied at forward time) covers the first `embedding_size` hidden
    /// dimensions, and `dense_weight` the remaining `hidden_size - embedding_size`.
    pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertLMPredictionHead
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        let transform = MobileBertPredictionHeadTransform::new(path / "transform", config);
        let dense_weight = (path / "dense").var(
            "weight",
            &[
                config.hidden_size - config.embedding_size,
                config.vocab_size,
            ],
            DEFAULT_KAIMING_UNIFORM,
        );
        let bias = path.var("bias", &[config.vocab_size], Init::Const(0.0));
        MobileBertLMPredictionHead {
            transform,
            dense_weight,
            bias,
        }
    }
    /// Computes vocabulary logits for `hidden_states` using the tied word
    /// `embeddings` (transposed) concatenated with `dense_weight` as decoder.
    pub fn forward(&self, hidden_states: &Tensor, embeddings: &Tensor) -> Tensor {
        let transformed = self.transform.forward(hidden_states);
        let decoder_weight = Tensor::cat(&[&embeddings.transpose(0, 1), &self.dense_weight], 0);
        transformed.matmul(&decoder_weight) + &self.bias
    }
}
/// Thin wrapper holding only the masked-LM prediction head (mirrors the
/// upstream module layout so weight names line up when loading checkpoints).
pub struct MobileBertOnlyMLMHead {
    predictions: MobileBertLMPredictionHead,
}
impl MobileBertOnlyMLMHead {
    /// Creates the MLM head under the `predictions` variable path.
    pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertOnlyMLMHead
    where
        P: Borrow<nn::Path<'p>>,
    {
        MobileBertOnlyMLMHead {
            predictions: MobileBertLMPredictionHead::new(p.borrow() / "predictions", config),
        }
    }
    /// Delegates to the inner prediction head.
    pub fn forward(&self, hidden_states: &Tensor, embeddings: &Tensor) -> Tensor {
        self.predictions.forward(hidden_states, embeddings)
    }
}
/// Base MobileBERT model: embeddings, encoder, and an optional pooler.
pub struct MobileBertModel {
    embeddings: MobileBertEmbeddings,
    encoder: MobileBertEncoder,
    // Present only when built with `add_pooling_layer = true`.
    pooler: Option<MobileBertPooler>,
    // Cached position ids of shape (1, max_position_embeddings),
    // sliced to the input length in `forward_t`.
    position_ids: Tensor,
}
impl MobileBertModel {
    /// Builds a new `MobileBertModel`.
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the model parameters
    /// * `config` - `MobileBertConfig` defining the architecture
    /// * `add_pooling_layer` - if `true`, a pooler is created and `forward_t`
    ///   returns `Some(pooled_output)`; otherwise `pooled_output` is `None`
    pub fn new<'p, P>(p: P, config: &MobileBertConfig, add_pooling_layer: bool) -> MobileBertModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        let embeddings = MobileBertEmbeddings::new(p / "embeddings", config);
        let encoder = MobileBertEncoder::new(p / "encoder", config);
        let pooler = if add_pooling_layer {
            Some(MobileBertPooler::new(p / "pooler", config))
        } else {
            None
        };
        // Precompute position ids 0..max_position_embeddings with a leading
        // batch dimension of 1; `forward_t` slices them to the sequence length.
        let position_ids =
            Tensor::arange(config.max_position_embeddings, (Kind::Int64, p.device()))
                .expand([1, -1], true);
        MobileBertModel {
            embeddings,
            encoder,
            pooler,
            position_ids,
        }
    }
    /// Forward pass through embeddings, encoder, and (optionally) the pooler.
    ///
    /// # Arguments
    ///
    /// * `input_ids` - token ids, used if `input_embeds` is not provided
    /// * `token_type_ids` - optional segment ids (defaults to all zeros)
    /// * `position_ids` - optional position ids (defaults to 0..seq_len)
    /// * `input_embeds` - optional precomputed embeddings
    /// * `attention_mask` - optional mask (1 = attend, 0 = masked); defaults to
    ///   all ones. Must have 2 or 3 dimensions.
    /// * `train` - enables dropout in submodules when `true`
    ///
    /// # Errors
    ///
    /// Returns an error if the `input_ids` / `input_embeds` pair is rejected by
    /// `get_shape_and_device_from_ids_embeddings_pair`, or if `attention_mask`
    /// does not have 2 or 3 dimensions.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> Result<MobileBertOutput, RustBertError> {
        let (input_shape, device) =
            get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
        // Defaults are materialized only when the caller did not supply the
        // tensor; the `Option`s keep the calculated tensors alive for borrowing.
        let calc_attention_mask = if attention_mask.is_none() {
            Some(Tensor::ones(input_shape.as_slice(), (Kind::Int64, device)))
        } else {
            None
        };
        let calc_token_type_ids = if token_type_ids.is_none() {
            Some(Tensor::zeros(input_shape.as_slice(), (Kind::Int64, device)))
        } else {
            None
        };
        // Default position ids: slice the cached arange to the sequence length.
        let calc_position_ids = if position_ids.is_none() {
            Some(self.position_ids.slice(1, 0, input_shape[1], 1))
        } else {
            None
        };
        let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());
        let attention_mask =
            attention_mask.unwrap_or_else(|| calc_attention_mask.as_ref().unwrap());
        // Insert singleton dimensions so the mask broadcasts over attention heads
        // (and query positions for 2D masks).
        let attention_mask = match attention_mask.dim() {
            3 => attention_mask.unsqueeze(1),
            2 => attention_mask.unsqueeze(1).unsqueeze(1),
            _ => {
                return Err(RustBertError::ValueError(
                    "Invalid attention mask dimension, must be 2 or 3".into(),
                ));
            }
        };
        let token_type_ids =
            token_type_ids.unwrap_or_else(|| calc_token_type_ids.as_ref().unwrap());
        let embedding_output = self.embeddings.forward_t(
            input_ids,
            token_type_ids,
            position_ids,
            input_embeds,
            train,
        )?;
        // Convert the 1/0 mask into an additive mask: 0.0 where attended,
        // -10000.0 where masked, cast to the embedding dtype.
        let attention_mask: Tensor = ((attention_mask.ones_like() - attention_mask) * -10000.0)
            .to_kind(embedding_output.kind());
        let encoder_output =
            self.encoder
                .forward_t(&embedding_output, Some(&attention_mask), train);
        let pooled_output = if let Some(pooler) = &self.pooler {
            Some(pooler.forward(&encoder_output.hidden_state))
        } else {
            None
        };
        Ok(MobileBertOutput {
            hidden_state: encoder_output.hidden_state,
            pooled_output,
            all_hidden_states: encoder_output.all_hidden_states,
            all_attentions: encoder_output.all_attentions,
        })
    }
    /// Returns the word embedding weight matrix, used for weight tying in the LM head.
    fn get_embeddings(&self) -> &Tensor {
        &self.embeddings.word_embeddings.ws
    }
}
/// MobileBERT with a masked-language-modeling head on top.
pub struct MobileBertForMaskedLM {
    mobilebert: MobileBertModel,
    classifier: MobileBertOnlyMLMHead,
}
impl MobileBertForMaskedLM {
    /// Creates the masked-LM model: base MobileBERT (without pooler) plus the MLM head.
    pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertForMaskedLM
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        MobileBertForMaskedLM {
            mobilebert: MobileBertModel::new(path / "mobilebert", config, false),
            classifier: MobileBertOnlyMLMHead::new(path / "cls", config),
        }
    }
    /// Forward pass: runs the base model, then decodes vocabulary logits using
    /// the tied word embeddings. See `MobileBertModel::forward_t` for arguments.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> Result<MobileBertMaskedLMOutput, RustBertError> {
        let base_model_output = self.mobilebert.forward_t(
            input_ids,
            token_type_ids,
            position_ids,
            input_embeds,
            attention_mask,
            train,
        )?;
        let word_embeddings = self.mobilebert.get_embeddings();
        let logits = self
            .classifier
            .forward(&base_model_output.hidden_state, word_embeddings);
        Ok(MobileBertMaskedLMOutput {
            logits,
            all_hidden_states: base_model_output.all_hidden_states,
            all_attentions: base_model_output.all_attentions,
        })
    }
}
/// MobileBERT with a sequence-classification head (dropout + linear) on the pooled output.
pub struct MobileBertForSequenceClassification {
    mobilebert: MobileBertModel,
    dropout: Dropout,
    classifier: nn::Linear,
}
impl MobileBertForSequenceClassification {
    /// Creates the classifier: base MobileBERT (with pooler), dropout, and a
    /// linear layer sized from the configuration's `id2label` mapping.
    ///
    /// # Errors
    ///
    /// Returns `InvalidConfigurationError` when `id2label` is absent.
    pub fn new<'p, P>(
        p: P,
        config: &MobileBertConfig,
    ) -> Result<MobileBertForSequenceClassification, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        let label_mapping = config.id2label.as_ref().ok_or_else(|| {
            RustBertError::InvalidConfigurationError(
                "num_labels not provided in configuration".to_string(),
            )
        })?;
        Ok(MobileBertForSequenceClassification {
            mobilebert: MobileBertModel::new(path / "mobilebert", config, true),
            dropout: Dropout::new(config.hidden_dropout_prob),
            classifier: nn::linear(
                path / "classifier",
                config.hidden_size,
                label_mapping.len() as i64,
                Default::default(),
            ),
        })
    }
    /// Forward pass: pooled output -> dropout -> linear classifier.
    /// See `MobileBertModel::forward_t` for arguments.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> Result<MobileBertSequenceClassificationOutput, RustBertError> {
        let base_model_output = self.mobilebert.forward_t(
            input_ids,
            token_type_ids,
            position_ids,
            input_embeds,
            attention_mask,
            train,
        )?;
        // Pooled output is always present: the model was built with a pooler.
        let pooled = base_model_output.pooled_output.unwrap();
        let logits = pooled.apply_t(&self.dropout, train).apply(&self.classifier);
        Ok(MobileBertSequenceClassificationOutput {
            logits,
            all_hidden_states: base_model_output.all_hidden_states,
            all_attentions: base_model_output.all_attentions,
        })
    }
}
/// MobileBERT with an extractive question-answering head producing
/// start/end span logits per token.
pub struct MobileBertForQuestionAnswering {
    mobilebert: MobileBertModel,
    qa_outputs: nn::Linear,
}
impl MobileBertForQuestionAnswering {
    /// Creates the QA model: base MobileBERT (without pooler) and a 2-unit
    /// linear head (start and end logits).
    pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertForQuestionAnswering
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        MobileBertForQuestionAnswering {
            mobilebert: MobileBertModel::new(path / "mobilebert", config, false),
            qa_outputs: nn::linear(path / "qa_outputs", config.hidden_size, 2, Default::default()),
        }
    }
    /// Forward pass: per-token hidden states -> 2-unit head, split into start
    /// and end logits. See `MobileBertModel::forward_t` for arguments.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> Result<MobileBertQuestionAnsweringOutput, RustBertError> {
        let base_model_output = self.mobilebert.forward_t(
            input_ids,
            token_type_ids,
            position_ids,
            input_embeds,
            attention_mask,
            train,
        )?;
        // Split the 2-channel output along the last dimension, then drop it.
        let logits = base_model_output
            .hidden_state
            .apply(&self.qa_outputs)
            .split(1, -1);
        let start_logits = logits[0].squeeze_dim(-1);
        let end_logits = logits[1].squeeze_dim(-1);
        Ok(MobileBertQuestionAnsweringOutput {
            start_logits,
            end_logits,
            all_hidden_states: base_model_output.all_hidden_states,
            all_attentions: base_model_output.all_attentions,
        })
    }
}
/// MobileBERT with a multiple-choice head: each choice is scored by a
/// single-logit linear layer applied to the pooled output.
pub struct MobileBertForMultipleChoice {
    mobilebert: MobileBertModel,
    dropout: Dropout,
    classifier: nn::Linear,
}
impl MobileBertForMultipleChoice {
    /// Creates the multiple-choice model: base MobileBERT (with pooler),
    /// dropout, and a 1-unit linear scorer.
    pub fn new<'p, P>(p: P, config: &MobileBertConfig) -> MobileBertForMultipleChoice
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        let mobilebert = MobileBertModel::new(p / "mobilebert", config, true);
        let dropout = Dropout::new(config.hidden_dropout_prob);
        // One output unit per flattened (example, choice) row; scores are
        // regrouped per example in `forward_t`.
        let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
        MobileBertForMultipleChoice {
            mobilebert,
            dropout,
            classifier,
        }
    }
    /// Forward pass. `input_ids` is expected shaped (batch, num_choices, seq_len)
    /// and `input_embeds` (batch, num_choices, seq_len, dim); choices are
    /// flattened into the batch dimension before running the base model, and the
    /// resulting logits are reshaped to (batch, num_choices).
    ///
    /// # Panics
    ///
    /// Panics if both `input_ids` and `input_embeds` are `None`.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> Result<MobileBertSequenceClassificationOutput, RustBertError> {
        let (input_ids, num_choices) = match input_ids {
            Some(value) => (
                Some(value.view((-1, *value.size().last().unwrap()))),
                value.size()[1],
            ),
            None => (
                None,
                input_embeds
                    .as_ref()
                    .expect("At least one of input ids or input_embeds must be provided")
                    .size()[1],
            ),
        };
        let attention_mask =
            attention_mask.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
        let token_type_ids =
            token_type_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
        // Fix: keep the sequence axis when flattening embeddings. The previous
        // implementation flattened `input_embeds` to two dimensions
        // (batch * choices * seq_len, dim), dropping the sequence axis; the
        // upstream implementation reshapes to (-1, seq_len, dim).
        let input_embeds = input_embeds.map(|tensor| {
            let size = tensor.size();
            tensor.view((-1, size[size.len() - 2], size[size.len() - 1]))
        });
        let position_ids =
            position_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
        let mobilebert_output = self.mobilebert.forward_t(
            input_ids.as_ref(),
            token_type_ids.as_ref(),
            position_ids.as_ref(),
            input_embeds.as_ref(),
            attention_mask.as_ref(),
            train,
        )?;
        // Pooled output is always present: the model was built with a pooler.
        let logits = mobilebert_output
            .pooled_output
            .unwrap()
            .apply_t(&self.dropout, train)
            .apply(&self.classifier)
            .view([-1, num_choices]);
        Ok(MobileBertSequenceClassificationOutput {
            logits,
            all_hidden_states: mobilebert_output.all_hidden_states,
            all_attentions: mobilebert_output.all_attentions,
        })
    }
}
/// MobileBERT with a token-classification head (dropout + linear) applied
/// to every token's hidden state (e.g. for NER or POS tagging).
pub struct MobileBertForTokenClassification {
    mobilebert: MobileBertModel,
    dropout: Dropout,
    classifier: nn::Linear,
}
impl MobileBertForTokenClassification {
    /// Creates the token classifier: base MobileBERT (without pooler), dropout,
    /// and a linear layer sized from the configuration's `id2label` mapping.
    ///
    /// # Errors
    ///
    /// Returns `InvalidConfigurationError` when `id2label` is absent.
    pub fn new<'p, P>(
        p: P,
        config: &MobileBertConfig,
    ) -> Result<MobileBertForTokenClassification, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        let label_count = config
            .id2label
            .as_ref()
            .ok_or_else(|| {
                RustBertError::InvalidConfigurationError(
                    "num_labels not provided in configuration".to_string(),
                )
            })?
            .len() as i64;
        Ok(MobileBertForTokenClassification {
            mobilebert: MobileBertModel::new(path / "mobilebert", config, false),
            dropout: Dropout::new(config.hidden_dropout_prob),
            classifier: nn::linear(
                path / "classifier",
                config.hidden_size,
                label_count,
                Default::default(),
            ),
        })
    }
    /// Forward pass: per-token hidden states -> dropout -> linear classifier.
    /// See `MobileBertModel::forward_t` for arguments.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        train: bool,
    ) -> Result<MobileBertTokenClassificationOutput, RustBertError> {
        let base_model_output = self.mobilebert.forward_t(
            input_ids,
            token_type_ids,
            position_ids,
            input_embeds,
            attention_mask,
            train,
        )?;
        let logits = base_model_output
            .hidden_state
            .apply_t(&self.dropout, train)
            .apply(&self.classifier);
        Ok(MobileBertTokenClassificationOutput {
            logits,
            all_hidden_states: base_model_output.all_hidden_states,
            all_attentions: base_model_output.all_attentions,
        })
    }
}
/// Output container for the MobileBERT base model.
pub struct MobileBertOutput {
    /// Last hidden states of the encoder
    pub hidden_state: Tensor,
    /// Pooled output (`Some` only when the model was built with a pooling layer)
    pub pooled_output: Option<Tensor>,
    /// Hidden states of all intermediate layers, if enabled in the configuration
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights of all intermediate layers, if enabled in the configuration
    pub all_attentions: Option<Vec<Tensor>>,
}
/// Output container for the MobileBERT masked-LM model.
pub struct MobileBertMaskedLMOutput {
    /// Vocabulary logits for each token position
    pub logits: Tensor,
    /// Hidden states of all intermediate layers, if enabled in the configuration
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights of all intermediate layers, if enabled in the configuration
    pub all_attentions: Option<Vec<Tensor>>,
}
/// Output container for sequence-classification and multiple-choice models.
pub struct MobileBertSequenceClassificationOutput {
    /// Classification logits
    pub logits: Tensor,
    /// Hidden states of all intermediate layers, if enabled in the configuration
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights of all intermediate layers, if enabled in the configuration
    pub all_attentions: Option<Vec<Tensor>>,
}
/// Output container for the token-classification model.
pub struct MobileBertTokenClassificationOutput {
    /// Per-token classification logits
    pub logits: Tensor,
    /// Hidden states of all intermediate layers, if enabled in the configuration
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights of all intermediate layers, if enabled in the configuration
    pub all_attentions: Option<Vec<Tensor>>,
}
/// Output container for the question-answering model.
pub struct MobileBertQuestionAnsweringOutput {
    /// Per-token span start logits
    pub start_logits: Tensor,
    /// Per-token span end logits
    pub end_logits: Tensor,
    /// Hidden states of all intermediate layers, if enabled in the configuration
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights of all intermediate layers, if enabled in the configuration
    pub all_attentions: Option<Vec<Tensor>>,
}