extern crate tch;
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::distilbert::embeddings::DistilBertEmbedding;
use crate::distilbert::transformer::Transformer;
use self::tch::{nn, Tensor};
use crate::common::dropout::Dropout;
use crate::Config;
/// Pretrained DistilBERT weight files.
///
/// Each constant is a `(local cache path, remote URL)` pair.
pub struct DistilBertModelResources;

/// Pretrained DistilBERT configuration files, as `(local cache path, remote URL)` pairs.
pub struct DistilBertConfigResources;

/// Pretrained DistilBERT vocabulary files, as `(local cache path, remote URL)` pairs.
pub struct DistilBertVocabResources;

impl DistilBertModelResources {
    /// Weights of the uncased model fine-tuned on SST-2.
    pub const DISTIL_BERT_SST2: (&str, &str) = (
        "distilbert-sst2/model.ot",
        "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
    );
    /// Weights of the base uncased model.
    pub const DISTIL_BERT: (&str, &str) = (
        "distilbert/model.ot",
        "https://cdn.huggingface.co/distilbert-base-uncased-rust_model.ot",
    );
    /// Weights of the cased model distilled on SQuAD.
    pub const DISTIL_BERT_SQUAD: (&str, &str) = (
        "distilbert-qa/model.ot",
        "https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-rust_model.ot",
    );
}

impl DistilBertConfigResources {
    /// Configuration of the uncased model fine-tuned on SST-2.
    pub const DISTIL_BERT_SST2: (&str, &str) = (
        "distilbert-sst2/config.json",
        "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-config.json",
    );
    /// Configuration of the base uncased model.
    pub const DISTIL_BERT: (&str, &str) = (
        "distilbert/config.json",
        "https://cdn.huggingface.co/distilbert-base-uncased-config.json",
    );
    /// Configuration of the cased model distilled on SQuAD.
    pub const DISTIL_BERT_SQUAD: (&str, &str) = (
        "distilbert-qa/config.json",
        "https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-config.json",
    );
}

impl DistilBertVocabResources {
    /// Vocabulary of the uncased model fine-tuned on SST-2.
    pub const DISTIL_BERT_SST2: (&str, &str) = (
        "distilbert-sst2/vocab.txt",
        "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-vocab.txt",
    );
    /// Vocabulary of the base model (served from the `bert-base-uncased` vocab file).
    pub const DISTIL_BERT: (&str, &str) = (
        "distilbert/vocab.txt",
        "https://cdn.huggingface.co/bert-base-uncased-vocab.txt",
    );
    /// Vocabulary of the SQuAD model (served from the `bert-large-cased` vocab file).
    pub const DISTIL_BERT_SQUAD: (&str, &str) = (
        "distilbert-qa/vocab.txt",
        "https://cdn.huggingface.co/bert-large-cased-vocab.txt",
    );
}
/// Activation function selected by a DistilBERT configuration.
///
/// With serde's default enum representation a unit variant maps to its own name,
/// so these variants (de)serialize as the strings `"gelu"` and `"relu"`.
// Variant names are intentionally lowercase so the serde mapping above matches
// the configuration strings without any `#[serde(rename)]` attributes.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Activation {
    /// Gaussian Error Linear Unit.
    gelu,
    /// Rectified Linear Unit.
    relu,
}
/// Hyper-parameters of a DistilBERT model, (de)serialized with serde.
///
/// No serde renames are applied, so each field name is also the key expected in the
/// serialized configuration (e.g. the `config.json` resources). `Option` fields may be
/// absent from that file.
#[derive(Debug, Serialize, Deserialize)]
pub struct DistilBertConfig {
/// Activation function name; not read in this file (presumably consumed by the transformer layers).
pub activation: Activation,
/// Dropout rate; not read in this file (presumably applied to attention weights — confirm in the transformer module).
pub attention_dropout: f64,
/// Model width: input and output size of the heads' `pre_classifier` and `vocab_transform` layers, and the input size of their output layers.
pub dim: i64,
/// Dropout rate; not read in this file (presumably the general hidden-state dropout used by the encoder).
pub dropout: f64,
/// Not read in this file; presumably the inner width of the transformer feed-forward layers.
pub hidden_dim: i64,
/// Label names by id. Required by the classification heads, whose output size is its length.
pub id2label: Option<HashMap<i32, String>>,
/// Not read in this file; presumably the std-dev used for weight initialisation.
pub initializer_range: f32,
/// Optional flag, not read in this file.
pub is_decoder: Option<bool>,
/// Label ids by name (inverse of `id2label`); not read in this file.
pub label2id: Option<HashMap<String, i32>>,
/// Not read in this file; presumably the size of the position-embedding table.
pub max_position_embeddings: i64,
/// Not read in this file; presumably the number of attention heads per layer.
pub n_heads: i64,
/// Not read in this file; presumably the number of transformer layers.
pub n_layers: i64,
/// Optional flag, not read in this file (presumably requests per-layer attention outputs).
pub output_attentions: Option<bool>,
/// Optional flag, not read in this file (presumably requests per-layer hidden-state outputs).
pub output_hidden_states: Option<bool>,
/// Optional flag, not read in this file.
pub output_past: Option<bool>,
/// Dropout rate of the question-answering head.
pub qa_dropout: f64,
/// Dropout rate of the sequence- and token-classification heads.
pub seq_classif_dropout: f64,
/// Not read in this file; presumably selects fixed sinusoidal position embeddings.
pub sinusoidal_pos_embds: bool,
/// Not read in this file; the trailing underscore is part of the serialized key name.
pub tie_weights_: bool,
/// Optional flag, not read in this file.
pub torchscript: Option<bool>,
/// Optional flag, not read in this file.
pub use_bfloat16: Option<bool>,
/// Vocabulary size: output width of the masked-LM `vocab_projector`.
pub vocab_size: i64,
}
/// Makes `DistilBertConfig` usable through the crate's `Config` trait (methods come from the trait's defaults, defined elsewhere).
impl Config<DistilBertConfig> for DistilBertConfig {}
/// Bare DistilBERT encoder: an embedding layer followed by the transformer stack.
pub struct DistilBertModel {
    embeddings: DistilBertEmbedding,
    transformer: Transformer,
}

impl DistilBertModel {
    /// Builds the encoder.
    ///
    /// # Arguments
    /// * `p` - variable-store path; sub-modules are created under
    ///   `distilbert/embeddings` and `distilbert/transformer` beneath it.
    /// * `config` - hyper-parameters forwarded to both sub-modules.
    pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModel {
        let root = &(p / "distilbert");
        DistilBertModel {
            embeddings: DistilBertEmbedding::new(&(root / "embeddings"), config),
            transformer: Transformer::new(&(root / "transformer"), config),
        }
    }

    /// Encodes either input ids or pre-computed embeddings.
    ///
    /// # Arguments
    /// * `input` - optional input ids, embedded with the model's embedding layer.
    /// * `mask` - optional mask, forwarded unchanged to the transformer.
    /// * `input_embeds` - optional pre-computed embeddings, used as-is.
    /// * `train` - forwarded to the embedding layer and the transformer.
    ///
    /// # Returns
    /// The transformer's output tuple unchanged; downstream heads destructure it as
    /// `(output, all_hidden_states, all_attentions)`.
    ///
    /// # Errors
    /// Exactly one of `input` / `input_embeds` must be `Some`; otherwise a static
    /// error message is returned.
    pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
                     -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
        let embedded = match (input, input_embeds) {
            (Some(_), Some(_)) => {
                return Err("Only one of input ids or input embeddings may be set");
            }
            (Some(ids), None) => ids.apply_t(&self.embeddings, train),
            (None, Some(precomputed)) => precomputed,
            (None, None) => {
                return Err("At least one of input ids or input embeddings must be set");
            }
        };
        Ok(self.transformer.forward_t(&embedded, mask, train))
    }
}
/// DistilBERT encoder with a sequence-classification head.
///
/// The head takes the encoder output at index 0 along dimension 1, then applies
/// `pre_classifier` -> ReLU -> dropout -> `classifier`.
pub struct DistilBertModelClassifier {
    distil_bert_model: DistilBertModel,
    pre_classifier: nn::Linear,
    classifier: nn::Linear,
    dropout: Dropout,
}

impl DistilBertModelClassifier {
    /// Builds the encoder (under `p / "distilbert"`) and the classification head
    /// (under `p / "pre_classifier"` and `p / "classifier"`).
    ///
    /// The dropout rate is `config.seq_classif_dropout`.
    ///
    /// # Panics
    /// Panics if `config.id2label` is `None`: its length sets the number of output labels.
    pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelClassifier {
        let distil_bert_model = DistilBertModel::new(p, config);
        let num_labels = config
            .id2label
            .as_ref()
            .expect("id2label must be provided for classifiers")
            .len() as i64;
        let pre_classifier = nn::linear(&(p / "pre_classifier"), config.dim, config.dim, Default::default());
        let classifier = nn::linear(&(p / "classifier"), config.dim, num_labels, Default::default());
        let dropout = Dropout::new(config.seq_classif_dropout);
        DistilBertModelClassifier { distil_bert_model, pre_classifier, classifier, dropout }
    }

    /// Runs the encoder and the classification head.
    ///
    /// # Arguments
    /// * `input` - optional input ids (mutually exclusive with `input_embeds`).
    /// * `mask` - optional mask, forwarded unchanged to the encoder.
    /// * `input_embeds` - optional pre-computed embeddings.
    /// * `train` - forwarded to the encoder and to the head's dropout (`apply_t`).
    ///
    /// # Returns
    /// `(logits, all_hidden_states, all_attentions)`; the last two are passed through
    /// from the encoder.
    ///
    /// # Errors
    /// Propagates the encoder's error when both or neither of `input` / `input_embeds` are set.
    pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
                     -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
        let (output, all_hidden_states, all_attentions) =
            self.distil_bert_model.forward_t(input, mask, input_embeds, train)?;
        // Index 0 along dim 1 — assumes output is (batch, seq_len, dim), i.e. the first
        // token's hidden state is used for classification. TODO confirm layout.
        let logits = output
            .select(1, 0)
            .apply(&self.pre_classifier)
            .relu()
            .apply_t(&self.dropout, train)
            .apply(&self.classifier);
        Ok((logits, all_hidden_states, all_attentions))
    }
}
/// DistilBERT encoder with a masked-language-modelling head.
///
/// The head applies `vocab_transform` -> GELU -> layer norm -> `vocab_projector`,
/// producing `config.vocab_size` logits per position.
pub struct DistilBertModelMaskedLM {
    distil_bert_model: DistilBertModel,
    vocab_transform: nn::Linear,
    vocab_layer_norm: nn::LayerNorm,
    vocab_projector: nn::Linear,
}

impl DistilBertModelMaskedLM {
    /// Builds the encoder (under `p / "distilbert"`) and the LM head (under
    /// `p / "vocab_transform"`, `p / "vocab_layer_norm"` and `p / "vocab_projector"`).
    pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelMaskedLM {
        let distil_bert_model = DistilBertModel::new(p, config);
        let vocab_transform = nn::linear(&(p / "vocab_transform"), config.dim, config.dim, Default::default());
        // eps is set explicitly to 1e-12; every other field keeps LayerNormConfig's default.
        let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
        let vocab_layer_norm = nn::layer_norm(p / "vocab_layer_norm", vec![config.dim], layer_norm_config);
        let vocab_projector = nn::linear(&(p / "vocab_projector"), config.dim, config.vocab_size, Default::default());
        DistilBertModelMaskedLM { distil_bert_model, vocab_transform, vocab_layer_norm, vocab_projector }
    }

    /// Runs the encoder and the masked-LM head.
    ///
    /// # Arguments
    /// * `input` - optional input ids (mutually exclusive with `input_embeds`).
    /// * `mask` - optional mask, forwarded unchanged to the encoder.
    /// * `input_embeds` - optional pre-computed embeddings.
    /// * `train` - forwarded to the encoder (the head itself has no dropout).
    ///
    /// # Returns
    /// `(vocab_logits, all_hidden_states, all_attentions)`; the last two are passed
    /// through from the encoder.
    ///
    /// # Errors
    /// Propagates the encoder's error when both or neither of `input` / `input_embeds` are set.
    pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
                     -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
        let (output, all_hidden_states, all_attentions) =
            self.distil_bert_model.forward_t(input, mask, input_embeds, train)?;
        let logits = output
            .apply(&self.vocab_transform)
            .gelu()
            .apply(&self.vocab_layer_norm)
            .apply(&self.vocab_projector);
        Ok((logits, all_hidden_states, all_attentions))
    }
}
/// DistilBERT encoder with an extractive question-answering head.
///
/// The head projects every position to two logits, which are returned separately
/// as start and end logits.
pub struct DistilBertForQuestionAnswering {
    distil_bert_model: DistilBertModel,
    qa_outputs: nn::Linear,
    dropout: Dropout,
}

impl DistilBertForQuestionAnswering {
    /// Builds the encoder (under `p / "distilbert"`) and the QA projection (under
    /// `p / "qa_outputs"`).
    ///
    /// The dropout rate is `config.qa_dropout`.
    pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForQuestionAnswering {
        let distil_bert_model = DistilBertModel::new(p, config);
        // Two output features per position: one start logit and one end logit.
        let qa_outputs = nn::linear(&(p / "qa_outputs"), config.dim, 2, Default::default());
        let dropout = Dropout::new(config.qa_dropout);
        DistilBertForQuestionAnswering { distil_bert_model, qa_outputs, dropout }
    }

    /// Runs the encoder and the QA head.
    ///
    /// # Arguments
    /// * `input` - optional input ids (mutually exclusive with `input_embeds`).
    /// * `mask` - optional mask, forwarded unchanged to the encoder.
    /// * `input_embeds` - optional pre-computed embeddings.
    /// * `train` - forwarded to the encoder and to the head's dropout (`apply_t`).
    ///
    /// # Returns
    /// `(start_logits, end_logits, all_hidden_states, all_attentions)`. The logits have
    /// the size-2 last dimension of `qa_outputs` split off and squeezed away; the last
    /// two elements are passed through from the encoder.
    ///
    /// # Errors
    /// Propagates the encoder's error when both or neither of `input` / `input_embeds` are set.
    pub fn forward_t(&self,
                     input: Option<Tensor>,
                     mask: Option<Tensor>,
                     input_embeds: Option<Tensor>,
                     train: bool)
                     -> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
        let (output, all_hidden_states, all_attentions) =
            self.distil_bert_model.forward_t(input, mask, input_embeds, train)?;
        // Splitting the size-2 last dimension into chunks of 1 yields exactly two tensors.
        let logits = output
            .apply_t(&self.dropout, train)
            .apply(&self.qa_outputs)
            .split(1, -1);
        let start_logits = logits[0].squeeze1(-1);
        let end_logits = logits[1].squeeze1(-1);
        Ok((start_logits, end_logits, all_hidden_states, all_attentions))
    }
}
/// DistilBERT encoder with a per-token classification head (dropout -> `classifier`).
pub struct DistilBertForTokenClassification {
    distil_bert_model: DistilBertModel,
    classifier: nn::Linear,
    dropout: Dropout,
}

impl DistilBertForTokenClassification {
    /// Builds the encoder (under `p / "distilbert"`) and the token classifier (under
    /// `p / "classifier"`).
    ///
    /// # Panics
    /// Panics if `config.id2label` is `None`: its length sets the number of output labels.
    pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForTokenClassification {
        let distil_bert_model = DistilBertModel::new(p, config);
        let num_labels = config
            .id2label
            .as_ref()
            .expect("id2label must be provided for classifiers")
            .len() as i64;
        let classifier = nn::linear(&(p / "classifier"), config.dim, num_labels, Default::default());
        // NOTE(review): this reuses `seq_classif_dropout` rather than the generic
        // `config.dropout`; confirm this is the intended rate for token classification.
        let dropout = Dropout::new(config.seq_classif_dropout);
        DistilBertForTokenClassification { distil_bert_model, classifier, dropout }
    }

    /// Runs the encoder and classifies every position of its output.
    ///
    /// # Arguments
    /// * `input` - optional input ids (mutually exclusive with `input_embeds`).
    /// * `mask` - optional mask, forwarded unchanged to the encoder.
    /// * `input_embeds` - optional pre-computed embeddings.
    /// * `train` - forwarded to the encoder and to the head's dropout (`apply_t`).
    ///
    /// # Returns
    /// `(logits, all_hidden_states, all_attentions)`; the last two are passed through
    /// from the encoder.
    ///
    /// # Errors
    /// Propagates the encoder's error when both or neither of `input` / `input_embeds` are set.
    pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
                     -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
        let (output, all_hidden_states, all_attentions) =
            self.distil_bert_model.forward_t(input, mask, input_embeds, train)?;
        let logits = output
            .apply_t(&self.dropout, train)
            .apply(&self.classifier);
        Ok((logits, all_hidden_states, all_attentions))
    }
}