extern crate tch;
use self::tch::{nn, Tensor};
use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::distilbert::embeddings::DistilBertEmbedding;
use crate::distilbert::transformer::{DistilBertTransformerOutput, Transformer};
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::{borrow::Borrow, collections::HashMap};
/// Registry of pretrained DistilBERT model-weight resources; each constant is a (local alias, download URL) pair.
pub struct DistilBertModelResources;
/// Registry of pretrained DistilBERT configuration-file resources; each constant is a (local alias, download URL) pair.
pub struct DistilBertConfigResources;
/// Registry of pretrained DistilBERT vocabulary resources; each constant is a (local alias, download URL) pair.
pub struct DistilBertVocabResources;
impl DistilBertModelResources {
    /// Weights for DistilBERT fine-tuned on SST-2 (English sentiment classification).
    pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
        "distilbert-sst2/model",
        "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/rust_model.ot",
    );
    /// Weights for the base (uncased) DistilBERT model.
    pub const DISTIL_BERT: (&'static str, &'static str) = (
        "distilbert/model",
        "https://huggingface.co/distilbert-base-uncased/resolve/main/rust_model.ot",
    );
    /// Weights for DistilBERT (cased) distilled on SQuAD for question answering.
    pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
        "distilbert-qa/model",
        "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/rust_model.ot",
    );
    /// Weights for the multilingual DistilUSE sentence-embedding model.
    pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
        "distiluse-base-multilingual-cased/model",
        "https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/rust_model.ot",
    );
}
impl DistilBertConfigResources {
    /// Configuration for DistilBERT fine-tuned on SST-2 (English sentiment classification).
    pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
        "distilbert-sst2/config",
        "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/config.json",
    );
    /// Configuration for the base (uncased) DistilBERT model.
    pub const DISTIL_BERT: (&'static str, &'static str) = (
        "distilbert/config",
        "https://huggingface.co/distilbert-base-uncased/resolve/main/config.json",
    );
    /// Configuration for DistilBERT (cased) distilled on SQuAD for question answering.
    pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
        "distilbert-qa/config",
        "https://huggingface.co/distilbert-base-cased-distilled-squad/resolve/main/config.json",
    );
    /// Configuration for the multilingual DistilUSE sentence-embedding model.
    pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
        "distiluse-base-multilingual-cased/config",
        "https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/config.json",
    );
}
impl DistilBertVocabResources {
    /// Vocabulary for DistilBERT fine-tuned on SST-2 (English sentiment classification).
    pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
        "distilbert-sst2/vocab",
        "https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/vocab.txt",
    );
    /// Vocabulary for the base (uncased) DistilBERT model.
    /// Note: served from the `bert-base-uncased` repository — presumably the distilled
    /// model shares its teacher's vocabulary; verify against upstream if changing.
    pub const DISTIL_BERT: (&'static str, &'static str) = (
        "distilbert/vocab",
        "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
    );
    /// Vocabulary for DistilBERT (cased) distilled on SQuAD.
    /// Note: served from the `bert-large-cased` repository (shared cased vocabulary).
    pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
        "distilbert-qa/vocab",
        "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
    );
    /// Vocabulary for the multilingual DistilUSE sentence-embedding model.
    pub const DISTILUSE_BASE_MULTILINGUAL_CASED: (&'static str, &'static str) = (
        "distiluse-base-multilingual-cased/vocab",
        "https://huggingface.co/sentence-transformers/distiluse-base-multilingual-cased/resolve/main/vocab.txt",
    );
}
/// # DistilBERT model configuration
/// Mirrors the fields of the upstream DistilBERT `config.json`; (de)serialized with serde.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DistilBertConfig {
    /// Activation function used by the model (see `Activation`)
    pub activation: Activation,
    /// Dropout probability applied in the attention layers — TODO confirm usage in `transformer`
    pub attention_dropout: f64,
    /// Hidden size of the model; used as in/out dimension of all head linear layers in this file
    pub dim: i64,
    /// General dropout probability — not read directly in this file
    pub dropout: f64,
    /// Inner dimension of the feed-forward layers — TODO confirm usage in `transformer`
    pub hidden_dim: i64,
    /// Optional mapping from label index to label name; its length sets the number
    /// of output classes for the classification heads defined below
    pub id2label: Option<HashMap<i64, String>>,
    /// Weight-initialization range — not read in this file
    pub initializer_range: f32,
    /// Whether the model is used as a decoder (unset for the base model)
    pub is_decoder: Option<bool>,
    /// Optional inverse mapping from label name to label index
    pub label2id: Option<HashMap<String, i64>>,
    /// Maximum sequence length supported by the position embeddings
    pub max_position_embeddings: i64,
    /// Number of attention heads
    pub n_heads: i64,
    /// Number of transformer layers
    pub n_layers: i64,
    /// If set, request attention weights from the transformer — TODO confirm in `Transformer`
    pub output_attentions: Option<bool>,
    /// If set, request all intermediate hidden states — TODO confirm in `Transformer`
    pub output_hidden_states: Option<bool>,
    /// Kept for configuration-file compatibility — not read in this file
    pub output_past: Option<bool>,
    /// Dropout probability used by the question-answering head
    pub qa_dropout: f64,
    /// Dropout probability used by the sequence/token classification heads
    pub seq_classif_dropout: f64,
    /// Whether sinusoidal (non-learned) position embeddings are used — consumed by the embeddings module
    pub sinusoidal_pos_embds: bool,
    /// Whether embedding and projection weights are tied — not read in this file
    pub tie_weights_: bool,
    /// Size of the vocabulary (output dimension of the masked-LM projector)
    pub vocab_size: i64,
}
// Marker impl: opts DistilBertConfig into the crate's shared `Config` trait
// (presumably provides JSON loading via its default methods — defined in the crate root).
impl Config for DistilBertConfig {}
impl Default for DistilBertConfig {
    /// Returns a configuration populated with default DistilBERT hyper-parameters
    /// (6 layers, 12 heads, hidden size 768, GELU activation, 30522-token vocabulary).
    fn default() -> Self {
        Self {
            // Architecture
            vocab_size: 30522,
            dim: 768,
            hidden_dim: 3072,
            n_heads: 12,
            n_layers: 6,
            max_position_embeddings: 512,
            activation: Activation::gelu,
            sinusoidal_pos_embds: false,
            // Regularization
            dropout: 0.1,
            attention_dropout: 0.1,
            qa_dropout: 0.1,
            seq_classif_dropout: 0.2,
            // Initialization / weight sharing
            initializer_range: 0.02,
            tie_weights_: false,
            // Optional, task-specific settings left unset by default
            id2label: None,
            label2id: None,
            is_decoder: None,
            output_attentions: None,
            output_hidden_states: None,
            output_past: None,
        }
    }
}
/// # DistilBERT base model
/// Embedding layer followed by a transformer encoder stack; carries no task-specific head.
pub struct DistilBertModel {
    /// Input embeddings (see `DistilBertEmbedding`)
    embeddings: DistilBertEmbedding,
    /// Transformer encoder stack
    transformer: Transformer,
}
impl DistilBertModel {
    /// Builds a new `DistilBertModel`, registering its variables under `p / "distilbert"`.
    ///
    /// # Arguments
    /// * `p` - Variable store path for the model parameters
    /// * `config` - `DistilBertConfig` describing the model architecture
    pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let root = p.borrow() / "distilbert";
        let embeddings = DistilBertEmbedding::new(&root / "embeddings", config);
        let transformer = Transformer::new(root / "transformer", config);
        DistilBertModel {
            embeddings,
            transformer,
        }
    }

    /// Forward pass: embeds the input, then runs the transformer stack.
    ///
    /// # Arguments
    /// * `input` - Optional token ids; presumably mutually exclusive with
    ///   `input_embeds` (enforced by the embedding layer — TODO confirm)
    /// * `mask` - Optional attention mask forwarded to the transformer
    /// * `input_embeds` - Optional precomputed input embeddings
    /// * `train` - Enables dropout when true
    ///
    /// # Errors
    /// Propagates any `RustBertError` produced by the embedding layer.
    pub fn forward_t(
        &self,
        input: Option<&Tensor>,
        mask: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<DistilBertTransformerOutput, RustBertError> {
        let embedded = self.embeddings.forward_t(input, input_embeds, train)?;
        Ok(self.transformer.forward_t(&embedded, mask, train))
    }
}
/// # DistilBERT for sequence classification
/// Base model plus a pre-classifier projection, dropout, and a linear classification head.
pub struct DistilBertModelClassifier {
    distil_bert_model: DistilBertModel,
    pre_classifier: nn::Linear,
    classifier: nn::Linear,
    dropout: Dropout,
}
impl DistilBertModelClassifier {
    /// Creates a sequence classification head on top of a `DistilBertModel`.
    ///
    /// # Arguments
    /// * `p` - Variable store path for the model parameters
    /// * `config` - `DistilBertConfig`; `id2label` must be set, as its length
    ///   determines the number of output classes.
    ///
    /// # Errors
    /// Returns `RustBertError::InvalidConfigurationError` when `config.id2label` is `None`.
    pub fn new<'p, P>(
        p: P,
        config: &DistilBertConfig,
    ) -> Result<DistilBertModelClassifier, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        let distil_bert_model = DistilBertModel::new(p, config);
        let num_labels = match config.id2label.as_ref() {
            Some(labels) => labels.len() as i64,
            None => {
                return Err(RustBertError::InvalidConfigurationError(
                    "num_labels not provided in configuration".to_string(),
                ))
            }
        };
        let pre_classifier = nn::linear(
            p / "pre_classifier",
            config.dim,
            config.dim,
            Default::default(),
        );
        let classifier = nn::linear(p / "classifier", config.dim, num_labels, Default::default());
        Ok(DistilBertModelClassifier {
            distil_bert_model,
            pre_classifier,
            classifier,
            dropout: Dropout::new(config.seq_classif_dropout),
        })
    }

    /// Forward pass: classifies using the hidden state at sequence position 0.
    ///
    /// # Arguments
    /// * `input` - Optional token ids (alternative to `input_embeds`)
    /// * `mask` - Optional attention mask
    /// * `input_embeds` - Optional precomputed input embeddings
    /// * `train` - Enables dropout when true
    ///
    /// # Errors
    /// Propagates any `RustBertError` from the base model.
    pub fn forward_t(
        &self,
        input: Option<&Tensor>,
        mask: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<DistilBertSequenceClassificationOutput, RustBertError> {
        let base_model_output = self
            .distil_bert_model
            .forward_t(input, mask, input_embeds, train)?;
        // First-position token state -> pre-classifier + ReLU -> dropout -> classifier.
        let first_token = base_model_output.hidden_state.select(1, 0);
        let pooled = first_token.apply(&self.pre_classifier).relu();
        let logits = pooled.apply_t(&self.dropout, train).apply(&self.classifier);
        Ok(DistilBertSequenceClassificationOutput {
            logits,
            all_hidden_states: base_model_output.all_hidden_states,
            all_attentions: base_model_output.all_attentions,
        })
    }
}
/// # DistilBERT for masked language modeling
/// Base model plus a transform / layer-norm / vocabulary-projection prediction head.
pub struct DistilBertModelMaskedLM {
    distil_bert_model: DistilBertModel,
    vocab_transform: nn::Linear,
    vocab_layer_norm: nn::LayerNorm,
    vocab_projector: nn::Linear,
}
impl DistilBertModelMaskedLM {
    /// Creates a masked-language-model head on top of a `DistilBertModel`.
    ///
    /// # Arguments
    /// * `p` - Variable store path for the model parameters
    /// * `config` - `DistilBertConfig` describing the model architecture
    pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModelMaskedLM
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        let distil_bert_model = DistilBertModel::new(p, config);
        let vocab_transform = nn::linear(
            p / "vocab_transform",
            config.dim,
            config.dim,
            Default::default(),
        );
        let vocab_layer_norm = nn::layer_norm(
            p / "vocab_layer_norm",
            vec![config.dim],
            nn::LayerNormConfig {
                eps: 1e-12, // matches the upstream DistilBERT layer-norm epsilon
                ..Default::default()
            },
        );
        let vocab_projector = nn::linear(
            p / "vocab_projector",
            config.dim,
            config.vocab_size,
            Default::default(),
        );
        DistilBertModelMaskedLM {
            distil_bert_model,
            vocab_transform,
            vocab_layer_norm,
            vocab_projector,
        }
    }

    /// Forward pass: produces per-position prediction scores over the vocabulary.
    ///
    /// # Arguments
    /// * `input` - Optional token ids (alternative to `input_embeds`)
    /// * `mask` - Optional attention mask
    /// * `input_embeds` - Optional precomputed input embeddings
    /// * `train` - Enables dropout when true
    ///
    /// # Errors
    /// Propagates any `RustBertError` from the base model.
    pub fn forward_t(
        &self,
        input: Option<&Tensor>,
        mask: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<DistilBertMaskedLMOutput, RustBertError> {
        let base_model_output = self
            .distil_bert_model
            .forward_t(input, mask, input_embeds, train)?;
        // transform -> GELU -> layer norm -> project to vocabulary size
        let transformed = base_model_output
            .hidden_state
            .apply(&self.vocab_transform)
            .gelu("none");
        let prediction_scores = transformed
            .apply(&self.vocab_layer_norm)
            .apply(&self.vocab_projector);
        Ok(DistilBertMaskedLMOutput {
            prediction_scores,
            all_hidden_states: base_model_output.all_hidden_states,
            all_attentions: base_model_output.all_attentions,
        })
    }
}
/// # DistilBERT for extractive question answering
/// Base model plus dropout and a 2-output linear head (answer-span start / end logits).
pub struct DistilBertForQuestionAnswering {
    distil_bert_model: DistilBertModel,
    qa_outputs: nn::Linear,
    dropout: Dropout,
}
impl DistilBertForQuestionAnswering {
    /// Creates a span-extraction question-answering head on top of a `DistilBertModel`.
    ///
    /// # Arguments
    /// * `p` - Variable store path for the model parameters
    /// * `config` - `DistilBertConfig` describing the model architecture
    pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertForQuestionAnswering
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        let distil_bert_model = DistilBertModel::new(p, config);
        // Two outputs per position: start-of-span and end-of-span logits.
        let qa_outputs = nn::linear(p / "qa_outputs", config.dim, 2, Default::default());
        DistilBertForQuestionAnswering {
            distil_bert_model,
            qa_outputs,
            dropout: Dropout::new(config.qa_dropout),
        }
    }

    /// Forward pass: produces start and end logits for answer-span extraction.
    ///
    /// # Arguments
    /// * `input` - Optional token ids (alternative to `input_embeds`)
    /// * `mask` - Optional attention mask
    /// * `input_embeds` - Optional precomputed input embeddings
    /// * `train` - Enables dropout when true
    ///
    /// # Errors
    /// Propagates any `RustBertError` from the base model.
    pub fn forward_t(
        &self,
        input: Option<&Tensor>,
        mask: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<DistilBertQuestionAnsweringOutput, RustBertError> {
        let base_model_output = self
            .distil_bert_model
            .forward_t(input, mask, input_embeds, train)?;
        // Split the 2-channel output along the last dimension into start/end logits.
        let logits = base_model_output
            .hidden_state
            .apply_t(&self.dropout, train)
            .apply(&self.qa_outputs)
            .split(1, -1);
        let start_logits = logits[0].squeeze_dim(-1);
        let end_logits = logits[1].squeeze_dim(-1);
        Ok(DistilBertQuestionAnsweringOutput {
            start_logits,
            end_logits,
            all_hidden_states: base_model_output.all_hidden_states,
            all_attentions: base_model_output.all_attentions,
        })
    }
}
/// # DistilBERT for token-level classification (e.g. NER, POS tagging)
/// Base model plus dropout and a per-token linear classification head.
pub struct DistilBertForTokenClassification {
    distil_bert_model: DistilBertModel,
    classifier: nn::Linear,
    dropout: Dropout,
}
impl DistilBertForTokenClassification {
    /// Creates a token classification head on top of a `DistilBertModel`.
    ///
    /// # Arguments
    /// * `p` - Variable store path for the model parameters
    /// * `config` - `DistilBertConfig`; `id2label` must be set, as its length
    ///   determines the number of output labels.
    ///
    /// # Errors
    /// Returns `RustBertError::InvalidConfigurationError` when `config.id2label` is `None`.
    pub fn new<'p, P>(
        p: P,
        config: &DistilBertConfig,
    ) -> Result<DistilBertForTokenClassification, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        let distil_bert_model = DistilBertModel::new(p, config);
        let num_labels = match config.id2label.as_ref() {
            Some(labels) => labels.len() as i64,
            None => {
                return Err(RustBertError::InvalidConfigurationError(
                    "id2label must be provided for classifiers".to_string(),
                ))
            }
        };
        let classifier = nn::linear(p / "classifier", config.dim, num_labels, Default::default());
        // NOTE(review): this head reuses `seq_classif_dropout`; verify whether the
        // reference implementation intends `dropout` for token classification.
        let dropout = Dropout::new(config.seq_classif_dropout);
        Ok(DistilBertForTokenClassification {
            distil_bert_model,
            classifier,
            dropout,
        })
    }

    /// Forward pass: produces classification logits for every input position.
    ///
    /// # Arguments
    /// * `input` - Optional token ids (alternative to `input_embeds`)
    /// * `mask` - Optional attention mask
    /// * `input_embeds` - Optional precomputed input embeddings
    /// * `train` - Enables dropout when true
    ///
    /// # Errors
    /// Propagates any `RustBertError` from the base model.
    pub fn forward_t(
        &self,
        input: Option<&Tensor>,
        mask: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<DistilBertTokenClassificationOutput, RustBertError> {
        let base_model_output = self
            .distil_bert_model
            .forward_t(input, mask, input_embeds, train)?;
        let dropped = base_model_output.hidden_state.apply_t(&self.dropout, train);
        let logits = dropped.apply(&self.classifier);
        Ok(DistilBertTokenClassificationOutput {
            logits,
            all_hidden_states: base_model_output.all_hidden_states,
            all_attentions: base_model_output.all_attentions,
        })
    }
}
/// Sentence-embedding variant: reuses the base model directly, with no extra head.
pub type DistilBertForSentenceEmbeddings = DistilBertModel;
/// Output container for `DistilBertModelMaskedLM::forward_t`.
pub struct DistilBertMaskedLMOutput {
    /// Prediction scores over the vocabulary for each input position
    pub prediction_scores: Tensor,
    /// Hidden states of intermediate layers, if returned by the transformer
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights of all layers, if returned by the transformer
    pub all_attentions: Option<Vec<Tensor>>,
}
/// Output container for `DistilBertModelClassifier::forward_t`.
pub struct DistilBertSequenceClassificationOutput {
    /// Classification logits for the input sequence
    pub logits: Tensor,
    /// Hidden states of intermediate layers, if returned by the transformer
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights of all layers, if returned by the transformer
    pub all_attentions: Option<Vec<Tensor>>,
}
/// Output container for `DistilBertForTokenClassification::forward_t`.
pub struct DistilBertTokenClassificationOutput {
    /// Classification logits for each input position
    pub logits: Tensor,
    /// Hidden states of intermediate layers, if returned by the transformer
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights of all layers, if returned by the transformer
    pub all_attentions: Option<Vec<Tensor>>,
}
/// Output container for `DistilBertForQuestionAnswering::forward_t`.
pub struct DistilBertQuestionAnsweringOutput {
    /// Logits for the answer-span start position
    pub start_logits: Tensor,
    /// Logits for the answer-span end position
    pub end_logits: Tensor,
    /// Hidden states of intermediate layers, if returned by the transformer
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights of all layers, if returned by the transformer
    pub all_attentions: Option<Vec<Tensor>>,
}