use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::bert::{Activation, BertConfig};
use crate::Config;
use crate::electra::embeddings::ElectraEmbeddings;
use tch::{nn, Tensor, Kind};
use crate::bert::encoder::BertEncoder;
use crate::common::activations::{_gelu, _relu, _mish};
use crate::common::dropout::Dropout;
/// Namespace type for the Electra model weight resources.
///
/// It carries no data; its associated constants are pairs of a relative
/// `model.ot` file name and the URL of the matching `rust_model.ot` file.
pub struct ElectraModelResources;
/// Namespace type for the Electra configuration resources.
///
/// It carries no data; its associated constants are pairs of a relative
/// `config.json` file name and the URL of the matching `config.json` file.
pub struct ElectraConfigResources;
/// Namespace type for the Electra vocabulary resources.
///
/// It carries no data; its associated constants are pairs of a relative
/// `vocab.txt` file name and the URL of the matching `vocab.txt` file.
pub struct ElectraVocabResources;
impl ElectraModelResources {
    /// Weights of the `electra-base-generator` checkpoint, given as a
    /// `(relative file name, source URL)` pair.
    pub const BASE_GENERATOR: (&'static str, &'static str) = (
        "electra-base-generator/model.ot",
        "https://cdn.huggingface.co/google/electra-base-generator/rust_model.ot",
    );

    /// Weights of the `electra-base-discriminator` checkpoint, given as a
    /// `(relative file name, source URL)` pair.
    pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
        "electra-base-discriminator/model.ot",
        "https://cdn.huggingface.co/google/electra-base-discriminator/rust_model.ot",
    );
}
impl ElectraConfigResources {
    /// Configuration of the `electra-base-generator` checkpoint, given as a
    /// `(relative file name, source URL)` pair.
    pub const BASE_GENERATOR: (&'static str, &'static str) = (
        "electra-base-generator/config.json",
        "https://cdn.huggingface.co/google/electra-base-generator/config.json",
    );

    /// Configuration of the `electra-base-discriminator` checkpoint, given as
    /// a `(relative file name, source URL)` pair.
    pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
        "electra-base-discriminator/config.json",
        "https://cdn.huggingface.co/google/electra-base-discriminator/config.json",
    );
}
impl ElectraVocabResources {
    /// Vocabulary of the `electra-base-generator` checkpoint, given as a
    /// `(relative file name, source URL)` pair.
    pub const BASE_GENERATOR: (&'static str, &'static str) = (
        "electra-base-generator/vocab.txt",
        "https://cdn.huggingface.co/google/electra-base-generator/vocab.txt",
    );

    /// Vocabulary of the `electra-base-discriminator` checkpoint, given as a
    /// `(relative file name, source URL)` pair.
    pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
        "electra-base-discriminator/vocab.txt",
        "https://cdn.huggingface.co/google/electra-base-discriminator/vocab.txt",
    );
}
/// Hyper-parameters of an Electra model, (de)serializable through serde.
///
/// Most fields are copied verbatim into a `BertConfig` when an `ElectraModel`
/// is built, because the encoder is a `BertEncoder`. Fields documented as
/// "not read in this file" may still be used by `ElectraEmbeddings`, which
/// receives the whole configuration.
#[derive(Debug, Serialize, Deserialize)]
pub struct ElectraConfig {
/// Activation forwarded to the encoder configuration and used by the discriminator head.
pub hidden_act: Activation,
/// Forwarded unchanged to the encoder configuration.
pub attention_probs_dropout_prob: f64,
/// Width of the embeddings; a projection layer is added when it differs from `hidden_size`.
/// Also the output width of the generator head and the input width of the LM head.
pub embedding_size: i64,
/// Forwarded to the encoder configuration; also the dropout probability of the token classifier.
pub hidden_dropout_prob: f64,
/// Width of the encoder hidden states and input width of the task heads.
pub hidden_size: i64,
/// Forwarded unchanged to the encoder configuration.
pub initializer_range: f32,
/// Not read in this file.
pub layer_norm_eps: Option<f64>,
/// Forwarded unchanged to the encoder configuration.
pub intermediate_size: i64,
/// Forwarded unchanged to the encoder configuration.
pub max_position_embeddings: i64,
/// Forwarded unchanged to the encoder configuration.
pub num_attention_heads: i64,
/// Forwarded unchanged to the encoder configuration.
pub num_hidden_layers: i64,
/// Forwarded unchanged to the encoder configuration.
pub type_vocab_size: i64,
/// Forwarded to the encoder configuration; also the output width of the masked-LM head.
pub vocab_size: i64,
/// Not read in this file.
pub pad_token_id: i64,
/// Not read in this file.
pub output_past: Option<bool>,
/// Forwarded unchanged to the encoder configuration.
pub output_attentions: Option<bool>,
/// Forwarded unchanged to the encoder configuration.
pub output_hidden_states: Option<bool>,
/// Label index to label name mapping; its number of entries sets the output width of the
/// token classifier, which panics at construction when this is `None`.
pub id2label: Option<HashMap<i64, String>>,
/// Label name to label index mapping, forwarded to the encoder configuration.
pub label2id: Option<HashMap<String, i64>>,
}
/// Implements the crate's `Config` trait for `ElectraConfig` without overriding any item,
/// so only what the trait itself provides is available.
impl Config<ElectraConfig> for ElectraConfig {}
/// Base Electra model: embeddings, an optional projection layer and a BERT encoder.
pub struct ElectraModel {
/// Embedding layer turning the inputs into embedding-sized hidden states.
embeddings: ElectraEmbeddings,
/// Linear layer mapping `embedding_size` to `hidden_size`; `None` when both sizes are equal.
embeddings_project: Option<nn::Linear>,
/// Encoder applied to the (possibly projected) embeddings.
encoder: BertEncoder,
}
impl ElectraModel {
    /// Builds a new `ElectraModel`.
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path under which the sub-modules are created, using the
    ///   `embeddings`, `embeddings_project` and `encoder` sub-paths.
    /// * `config` - Electra configuration. An `embeddings_project` linear layer is only
    ///   created when `config.embedding_size` differs from `config.hidden_size`.
    pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraModel {
        let embeddings = ElectraEmbeddings::new(&(p / "embeddings"), config);

        let embeddings_project = if config.embedding_size != config.hidden_size {
            Some(nn::linear(
                &(p / "embeddings_project"),
                config.embedding_size,
                config.hidden_size,
                Default::default(),
            ))
        } else {
            None
        };

        // The encoder is a plain BERT encoder, so the relevant subset of the Electra
        // configuration is re-expressed as a `BertConfig`.
        let bert_config = BertConfig {
            hidden_act: config.hidden_act.clone(),
            attention_probs_dropout_prob: config.attention_probs_dropout_prob,
            hidden_dropout_prob: config.hidden_dropout_prob,
            hidden_size: config.hidden_size,
            initializer_range: config.initializer_range,
            intermediate_size: config.intermediate_size,
            max_position_embeddings: config.max_position_embeddings,
            num_attention_heads: config.num_attention_heads,
            num_hidden_layers: config.num_hidden_layers,
            type_vocab_size: config.type_vocab_size,
            vocab_size: config.vocab_size,
            output_attentions: config.output_attentions,
            output_hidden_states: config.output_hidden_states,
            is_decoder: None,
            id2label: config.id2label.clone(),
            label2id: config.label2id.clone(),
        };
        let encoder = BertEncoder::new(&(p / "encoder"), &bert_config);

        ElectraModel { embeddings, embeddings_project, encoder }
    }

    /// Runs the embeddings, the optional projection and the encoder.
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional input tensor; exactly one of `input_ids` and
    ///   `input_embeds` must be provided.
    /// * `mask` - Optional attention mask with 2 or 3 dimensions. When `None`, a mask of
    ///   ones (`Kind::Int64`) is created on the input device, shaped like `input_ids`
    ///   or like the first two dimensions of `input_embeds`.
    /// * `token_type_ids` - Forwarded as-is to the embedding layer.
    /// * `position_ids` - Forwarded as-is to the embedding layer.
    /// * `input_embeds` - Optional pre-computed embeddings (see `input_ids`).
    /// * `train` - Forwarded to the embedding layer and to the encoder.
    ///
    /// # Returns
    ///
    /// The tuple returned by the encoder: its output hidden state, followed by two
    /// optional vectors of tensors (per the encoder, all hidden states and all
    /// attentions).
    ///
    /// # Errors
    ///
    /// Returns an error message when both or neither of `input_ids` / `input_embeds` are
    /// set, when `input_embeds` has fewer than 2 dimensions, when the mask does not have
    /// 2 or 3 dimensions, or when the embedding layer reports an error.
    pub fn forward_t(
        &self,
        input_ids: Option<Tensor>,
        mask: Option<Tensor>,
        token_type_ids: Option<Tensor>,
        position_ids: Option<Tensor>,
        input_embeds: Option<Tensor>,
        train: bool,
    ) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
        let (input_shape, device) = match (&input_ids, &input_embeds) {
            (Some(_), Some(_)) => {
                return Err("Only one of input ids or input embeddings may be set");
            }
            (Some(ids), None) => (ids.size(), ids.device()),
            (None, Some(embeds)) => {
                // Only the two leading dimensions are used for the default mask shape.
                // Reject lower-rank embeddings explicitly instead of panicking on an
                // out-of-bounds index.
                let embeds_shape = embeds.size();
                if embeds_shape.len() < 2 {
                    return Err("Input embeddings must have at least 2 dimensions");
                }
                (embeds_shape[..2].to_vec(), embeds.device())
            }
            (None, None) => {
                return Err("At least one of input ids or input embeddings must be set");
            }
        };

        let mask = mask.unwrap_or_else(|| Tensor::ones(&input_shape, (Kind::Int64, device)));

        // Insert singleton dimensions after the first one so the mask has 4 dimensions.
        // NOTE(review): the mask values are forwarded to the encoder unchanged (only
        // reshaped) — confirm `BertEncoder` expects this form rather than an additive mask.
        let extended_attention_mask = match mask.dim() {
            3 => mask.unsqueeze(1),
            2 => mask.unsqueeze(1).unsqueeze(1),
            _ => {
                return Err("Invalid attention mask dimension, must be 2 or 3");
            }
        };

        let hidden_states = self.embeddings.forward_t(
            input_ids,
            token_type_ids,
            position_ids,
            input_embeds,
            train,
        )?;

        let hidden_states = match &self.embeddings_project {
            Some(projection) => hidden_states.apply(projection),
            None => hidden_states,
        };

        Ok(self.encoder.forward_t(
            &hidden_states,
            &Some(extended_attention_mask),
            &None,
            &None,
            train,
        ))
    }
}
/// Head producing one value per position from encoder hidden states.
pub struct ElectraDiscriminatorHead {
/// Linear layer of size `hidden_size` x `hidden_size`.
dense: nn::Linear,
/// Linear layer mapping `hidden_size` to a single output.
dense_prediction: nn::Linear,
/// Activation applied between the two linear layers, chosen from `config.hidden_act`.
activation: Box<dyn Fn(&Tensor) -> Tensor>,
}
impl ElectraDiscriminatorHead {
    /// Creates the head's two linear layers under the `dense` and `dense_prediction`
    /// sub-paths of `p`, and picks the activation matching `config.hidden_act`.
    pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraDiscriminatorHead {
        let hidden_size = config.hidden_size;
        let dense = nn::linear(&(p / "dense"), hidden_size, hidden_size, Default::default());
        let dense_prediction =
            nn::linear(&(p / "dense_prediction"), hidden_size, 1, Default::default());

        let activation_fn: fn(&Tensor) -> Tensor = match config.hidden_act {
            Activation::gelu => _gelu,
            Activation::relu => _relu,
            Activation::mish => _mish,
        };

        ElectraDiscriminatorHead {
            dense,
            dense_prediction,
            activation: Box::new(activation_fn),
        }
    }

    /// Applies `dense`, the activation and `dense_prediction` to `encoder_hidden_states`.
    ///
    /// The result goes through an unqualified `squeeze()`, which drops every dimension
    /// of size 1: the trailing prediction dimension, but also any other dimension
    /// that happens to have length 1.
    pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
        let activated = (self.activation)(&encoder_hidden_states.apply(&self.dense));
        let logits = activated.apply(&self.dense_prediction);
        logits.squeeze()
    }
}
/// Head mapping encoder hidden states to normalized, embedding-sized states.
pub struct ElectraGeneratorHead {
/// Linear layer mapping `hidden_size` to `embedding_size`.
dense: nn::Linear,
/// Layer normalization over the `embedding_size` dimension, applied last.
layer_norm: nn::LayerNorm,
/// Activation applied after `dense`; always GELU for this head.
activation: Box<dyn Fn(&Tensor) -> Tensor>,
}
impl ElectraGeneratorHead {
    /// Creates the head under `p`: a layer normalization (`LayerNorm` sub-path) over
    /// `config.embedding_size`, a linear layer (`dense` sub-path) from
    /// `config.hidden_size` to `config.embedding_size`, and a GELU activation.
    pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraGeneratorHead {
        let embedding_size = config.embedding_size;

        // The layer normalization is created before the linear layer, as in the
        // original construction order.
        let layer_norm = nn::layer_norm(p / "LayerNorm", vec![embedding_size], Default::default());
        let dense = nn::linear(
            &(p / "dense"),
            config.hidden_size,
            embedding_size,
            Default::default(),
        );

        ElectraGeneratorHead {
            layer_norm,
            dense,
            activation: Box::new(_gelu),
        }
    }

    /// Applies the linear layer, then the activation, then the layer normalization.
    pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
        let projected = encoder_hidden_states.apply(&self.dense);
        (self.activation)(&projected).apply(&self.layer_norm)
    }
}
/// Electra model with a generator head and a vocabulary-sized linear output layer.
pub struct ElectraForMaskedLM {
/// Base Electra model.
electra: ElectraModel,
/// Generator head applied to the base model output.
generator_head: ElectraGeneratorHead,
/// Linear layer mapping `embedding_size` to `vocab_size`.
lm_head: nn::Linear,
}
impl ElectraForMaskedLM {
    /// Builds the base model (`electra` sub-path), the generator head
    /// (`generator_predictions` sub-path) and the output layer (`generator_lm_head`
    /// sub-path, `config.embedding_size` to `config.vocab_size`).
    pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraForMaskedLM {
        ElectraForMaskedLM {
            electra: ElectraModel::new(&(p / "electra"), config),
            generator_head: ElectraGeneratorHead::new(&(p / "generator_predictions"), config),
            lm_head: nn::linear(
                &(p / "generator_lm_head"),
                config.embedding_size,
                config.vocab_size,
                Default::default(),
            ),
        }
    }

    /// Runs the base model, then the generator head and the output layer.
    ///
    /// All arguments are forwarded unchanged to `ElectraModel::forward_t`.
    ///
    /// # Returns
    ///
    /// The output-layer result (last dimension of size `vocab_size`), followed by the
    /// two optional tensor vectors returned by the base model.
    ///
    /// # Panics
    ///
    /// Panics if the base model returns an error (for example when both or neither of
    /// `input_ids` and `input_embeds` are provided).
    pub fn forward_t(
        &self,
        input_ids: Option<Tensor>,
        mask: Option<Tensor>,
        token_type_ids: Option<Tensor>,
        position_ids: Option<Tensor>,
        input_embeds: Option<Tensor>,
        train: bool,
    ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        let base_output = self
            .electra
            .forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train)
            .unwrap();
        let (sequence_output, all_hidden_states, all_attentions) = base_output;

        let prediction_scores = self
            .generator_head
            .forward(&sequence_output)
            .apply(&self.lm_head);

        (prediction_scores, all_hidden_states, all_attentions)
    }
}
/// Electra model with a discriminator head producing one probability per position.
pub struct ElectraDiscriminator {
/// Base Electra model.
electra: ElectraModel,
/// Discriminator head applied to the base model output.
discriminator_head: ElectraDiscriminatorHead,
}
impl ElectraDiscriminator {
    /// Builds the base model (`electra` sub-path) and the discriminator head
    /// (`discriminator_predictions` sub-path).
    pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraDiscriminator {
        ElectraDiscriminator {
            electra: ElectraModel::new(&(p / "electra"), config),
            discriminator_head: ElectraDiscriminatorHead::new(
                &(p / "discriminator_predictions"),
                config,
            ),
        }
    }

    /// Runs the base model, then the discriminator head followed by a sigmoid.
    ///
    /// All arguments are forwarded unchanged to `ElectraModel::forward_t`.
    ///
    /// # Returns
    ///
    /// The sigmoid of the discriminator head output, followed by the two optional
    /// tensor vectors returned by the base model.
    ///
    /// # Panics
    ///
    /// Panics if the base model returns an error (for example when both or neither of
    /// `input_ids` and `input_embeds` are provided).
    pub fn forward_t(
        &self,
        input_ids: Option<Tensor>,
        mask: Option<Tensor>,
        token_type_ids: Option<Tensor>,
        position_ids: Option<Tensor>,
        input_embeds: Option<Tensor>,
        train: bool,
    ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        let base_output = self
            .electra
            .forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train)
            .unwrap();
        let (sequence_output, all_hidden_states, all_attentions) = base_output;

        let logits = self.discriminator_head.forward(&sequence_output);

        (logits.sigmoid(), all_hidden_states, all_attentions)
    }
}
/// Electra model with dropout and a linear classifier applied to every position.
pub struct ElectraForTokenClassification {
/// Base Electra model.
electra: ElectraModel,
/// Dropout applied to the base model output, using `hidden_dropout_prob`.
dropout: Dropout,
/// Linear layer mapping `hidden_size` to the number of labels in `id2label`.
classifier: nn::Linear,
}
impl ElectraForTokenClassification {
    /// Builds the base model (`electra` sub-path), a dropout layer using
    /// `config.hidden_dropout_prob`, and a linear classifier (`classifier` sub-path)
    /// with one output per entry of `config.id2label`.
    ///
    /// # Panics
    ///
    /// Panics if `config.id2label` is `None`.
    pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraForTokenClassification {
        let electra = ElectraModel::new(&(p / "electra"), config);
        let dropout = Dropout::new(config.hidden_dropout_prob);

        let label_map = config
            .id2label
            .as_ref()
            .expect("id2label must be provided for classifiers");
        let num_labels = label_map.len() as i64;

        let classifier = nn::linear(
            &(p / "classifier"),
            config.hidden_size,
            num_labels,
            Default::default(),
        );

        ElectraForTokenClassification { electra, dropout, classifier }
    }

    /// Runs the base model, then dropout and the classifier on its output.
    ///
    /// All arguments are forwarded unchanged to `ElectraModel::forward_t`; `train` is
    /// also passed to the dropout layer.
    ///
    /// # Returns
    ///
    /// The classifier output (last dimension sized by the number of labels), followed
    /// by the two optional tensor vectors returned by the base model.
    ///
    /// # Panics
    ///
    /// Panics if the base model returns an error (for example when both or neither of
    /// `input_ids` and `input_embeds` are provided).
    pub fn forward_t(
        &self,
        input_ids: Option<Tensor>,
        mask: Option<Tensor>,
        token_type_ids: Option<Tensor>,
        position_ids: Option<Tensor>,
        input_embeds: Option<Tensor>,
        train: bool,
    ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        let base_output = self
            .electra
            .forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train)
            .unwrap();
        let (sequence_output, all_hidden_states, all_attentions) = base_output;

        let regularized = sequence_output.apply_t(&self.dropout, train);
        let logits = regularized.apply(&self.classifier);

        (logits, all_hidden_states, all_attentions)
    }
}