use crate::bert::encoder::BertEncoder;
use crate::bert::BertConfig;
use crate::common::activations::{Activation, TensorFunction};
use crate::common::dropout::Dropout;
use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
use crate::electra::embeddings::ElectraEmbeddings;
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
use tch::{nn, Kind, Tensor};
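
/// # ELECTRA pretrained model weight files
/// Each constant is a `(local alias, download URL)` pair pointing at model weights
/// hosted on the HuggingFace model hub.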
pub struct ElectraModelResources;
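
/// # ELECTRA pretrained configuration files
/// Each constant is a `(local alias, download URL)` pair for a `config.json` file.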
pub struct ElectraConfigResources;
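
/// # ELECTRA pretrained vocabulary files
/// Each constant is a `(local alias, download URL)` pair for a vocabulary (`vocab.txt`) file.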
pub struct ElectraVocabResources;
impl ElectraModelResources {
pub const BASE_GENERATOR: (&'static str, &'static str) = (
"electra-base-generator/model",
"https://huggingface.co/google/electra-base-generator/resolve/main/rust_model.ot",
);
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
"electra-base-discriminator/model",
"https://huggingface.co/google/electra-base-discriminator/resolve/main/rust_model.ot",
);
}
impl ElectraConfigResources {
pub const BASE_GENERATOR: (&'static str, &'static str) = (
"electra-base-generator/config",
"https://huggingface.co/google/electra-base-generator/resolve/main/config.json",
);
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
"electra-base-discriminator/config",
"https://huggingface.co/google/electra-base-discriminator/resolve/main/config.json",
);
}
impl ElectraVocabResources {
pub const BASE_GENERATOR: (&'static str, &'static str) = (
"electra-base-generator/vocab",
"https://huggingface.co/google/electra-base-generator/resolve/main/vocab.txt",
);
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
"electra-base-discriminator/vocab",
"https://huggingface.co/google/electra-base-discriminator/resolve/main/vocab.txt",
);
}
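
/// # ELECTRA model configuration
/// Defines the model architecture (hidden sizes, number of layers and attention heads,
/// vocabulary size, ...). Implements [`Config`], so it can be deserialized from a
/// HuggingFace-format `config.json` file. Note that ELECTRA decouples `embedding_size`
/// from `hidden_size`: when the two differ, the model inserts a linear projection after
/// the embedding layer.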
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ElectraConfig {
pub hidden_act: Activation,
pub attention_probs_dropout_prob: f64,
pub embedding_size: i64,
pub hidden_dropout_prob: f64,
pub hidden_size: i64,
pub initializer_range: f32,
pub layer_norm_eps: Option<f64>,
pub intermediate_size: i64,
pub max_position_embeddings: i64,
pub num_attention_heads: i64,
pub num_hidden_layers: i64,
pub type_vocab_size: i64,
pub vocab_size: i64,
pub pad_token_id: i64,
pub output_past: Option<bool>,
pub output_attentions: Option<bool>,
pub output_hidden_states: Option<bool>,
pub id2label: Option<HashMap<i64, String>>,
pub label2id: Option<HashMap<String, i64>>,
}
impl Config for ElectraConfig {}
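
/// Default values correspond to the ELECTRA-small generator configuration.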
impl Default for ElectraConfig {
fn default() -> Self {
ElectraConfig {
hidden_act: Activation::gelu,
attention_probs_dropout_prob: 0.1,
embedding_size: 128,
hidden_dropout_prob: 0.1,
hidden_size: 256,
initializer_range: 0.02,
layer_norm_eps: Some(1e-12),
intermediate_size: 1024,
max_position_embeddings: 512,
num_attention_heads: 4,
num_hidden_layers: 12,
type_vocab_size: 2,
vocab_size: 30522,
pad_token_id: 0,
output_past: None,
output_attentions: None,
output_hidden_states: None,
id2label: None,
label2id: None,
}
}
}
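
/// # ELECTRA base model
/// Embeddings followed by a BERT-style transformer encoder. When `embedding_size`
/// differs from `hidden_size`, an additional linear layer projects the embeddings to
/// the encoder width. This base model is shared by the generator, the discriminator
/// and the token classification model below.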
pub struct ElectraModel {
embeddings: ElectraEmbeddings,
embeddings_project: Option<nn::Linear>,
encoder: BertEncoder,
}
impl ElectraModel {
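    /// Builds a new `ElectraModel`.
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path for the model parameters
    /// * `config` - `ElectraConfig` defining the architecture
    ///
    /// # Example
    ///
    /// A minimal sketch; the `rust_bert::electra` re-export path and the config file
    /// location are assumptions:
    ///
    /// ```no_run
    /// use rust_bert::electra::{ElectraConfig, ElectraModel};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let vs = nn::VarStore::new(device);
    /// let config = ElectraConfig::from_file(config_path);
    /// let electra_model = ElectraModel::new(&vs.root() / "electra", &config);
    /// ```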
pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraModel
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let embeddings = ElectraEmbeddings::new(p / "embeddings", config);
let embeddings_project = if config.embedding_size != config.hidden_size {
Some(nn::linear(
p / "embeddings_project",
config.embedding_size,
config.hidden_size,
Default::default(),
))
} else {
None
};
let bert_config = BertConfig {
hidden_act: config.hidden_act,
attention_probs_dropout_prob: config.attention_probs_dropout_prob,
hidden_dropout_prob: config.hidden_dropout_prob,
hidden_size: config.hidden_size,
initializer_range: config.initializer_range,
intermediate_size: config.intermediate_size,
max_position_embeddings: config.max_position_embeddings,
num_attention_heads: config.num_attention_heads,
num_hidden_layers: config.num_hidden_layers,
type_vocab_size: config.type_vocab_size,
vocab_size: config.vocab_size,
output_attentions: config.output_attentions,
output_hidden_states: config.output_hidden_states,
is_decoder: None,
id2label: config.id2label.clone(),
label2id: config.label2id.clone(),
};
let encoder = BertEncoder::new(p / "encoder", &bert_config);
ElectraModel {
embeddings,
embeddings_project,
encoder,
}
}
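
    /// Forward pass through the base model.
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional token ids of shape (*batch size*, *sequence length*);
    ///   required if `input_embeds` is `None`.
    /// * `mask` - Optional attention mask of shape (*batch size*, *sequence length*);
    ///   defaults to all ones.
    /// * `token_type_ids` - Optional segment ids of the same shape.
    /// * `position_ids` - Optional position ids of the same shape.
    /// * `input_embeds` - Optional pre-computed embeddings of shape (*batch size*,
    ///   *sequence length*, *embedding_size*); required if `input_ids` is `None`.
    /// * `train` - Enables dropout when `true`.
    ///
    /// Returns an error if neither or both of `input_ids` and `input_embeds` are
    /// provided, or if the mask is not 2- or 3-dimensional.
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a model built as in [`ElectraModel::new`]:
    ///
    /// ```no_run
    /// # use rust_bert::electra::{ElectraConfig, ElectraModel};
    /// # use tch::{nn, no_grad, Device, Kind, Tensor};
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = ElectraConfig::default();
    /// # let electra_model = ElectraModel::new(&vs.root() / "electra", &config);
    /// let (batch_size, sequence_length) = (2, 16);
    /// let input_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
    /// let mask = Tensor::ones(&[batch_size, sequence_length], (Kind::Int64, device));
    ///
    /// let model_output = no_grad(|| {
    ///     electra_model.forward_t(Some(&input_ids), Some(&mask), None, None, None, false)
    /// });
    /// ```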
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> Result<ElectraModelOutput, RustBertError> {
let (input_shape, device) =
get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
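        // Default to a full attention mask (all ones) when none is provided.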
let calc_mask = if mask.is_none() {
Some(Tensor::ones(input_shape, (Kind::Int64, device)))
} else {
None
};
let mask = mask.unwrap_or_else(|| calc_mask.as_ref().unwrap());
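        // Add singleton dimensions so the mask broadcasts over attention heads
        // (and over query positions for 2-dimensional masks).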
let extended_attention_mask = match mask.dim() {
3 => mask.unsqueeze(1),
2 => mask.unsqueeze(1).unsqueeze(1),
_ => {
return Err(RustBertError::ValueError(
"Invalid attention mask dimension, must be 2 or 3".into(),
));
}
};
let hidden_states = self.embeddings.forward_t(
input_ids,
token_type_ids,
position_ids,
input_embeds,
train,
)?;
let hidden_states = match &self.embeddings_project {
Some(layer) => hidden_states.apply(layer),
None => hidden_states,
};
let encoder_output = self.encoder.forward_t(
&hidden_states,
Some(&extended_attention_mask),
None,
None,
train,
);
Ok(ElectraModelOutput {
hidden_state: encoder_output.hidden_state,
all_hidden_states: encoder_output.all_hidden_states,
all_attentions: encoder_output.all_attentions,
})
}
}
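
/// # ELECTRA discriminator head
/// A dense layer with the configured activation, followed by a linear projection to a
/// single logit per token, used to predict whether each token is original or replaced.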
pub struct ElectraDiscriminatorHead {
dense: nn::Linear,
dense_prediction: nn::Linear,
activation: TensorFunction,
}
impl ElectraDiscriminatorHead {
pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraDiscriminatorHead
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let dense = nn::linear(
p / "dense",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let dense_prediction = nn::linear(
p / "dense_prediction",
config.hidden_size,
1,
Default::default(),
);
let activation = config.hidden_act.get_function();
ElectraDiscriminatorHead {
dense,
dense_prediction,
activation,
}
}
pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
let output = encoder_hidden_states.apply(&self.dense);
let output = (self.activation.get_fn())(&output);
        // Squeeze only the trailing prediction dimension, so a batch size of 1 is preserved.
        output.apply(&self.dense_prediction).squeeze_dim(-1)
}
}
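
/// # ELECTRA generator head
/// Projects encoder hidden states back to `embedding_size`, then applies GELU and layer
/// normalization. Its output feeds the language modeling head of `ElectraForMaskedLM`.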
pub struct ElectraGeneratorHead {
dense: nn::Linear,
layer_norm: nn::LayerNorm,
activation: TensorFunction,
}
impl ElectraGeneratorHead {
pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraGeneratorHead
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let layer_norm = nn::layer_norm(
p / "LayerNorm",
vec![config.embedding_size],
Default::default(),
);
let dense = nn::linear(
p / "dense",
config.hidden_size,
config.embedding_size,
Default::default(),
);
let activation = Activation::gelu.get_function();
ElectraGeneratorHead {
dense,
layer_norm,
activation,
}
}
pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
let output = encoder_hidden_states.apply(&self.dense);
let output = (self.activation.get_fn())(&output);
output.apply(&self.layer_norm)
}
}
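
/// # ELECTRA for masked language modeling (generator)
/// Base model, generator head and a linear language modeling head projecting to the
/// vocabulary size.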
pub struct ElectraForMaskedLM {
electra: ElectraModel,
generator_head: ElectraGeneratorHead,
lm_head: nn::Linear,
}
impl ElectraForMaskedLM {
pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraForMaskedLM
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let electra = ElectraModel::new(p / "electra", config);
let generator_head = ElectraGeneratorHead::new(p / "generator_predictions", config);
let lm_head = nn::linear(
p / "generator_lm_head",
config.embedding_size,
config.vocab_size,
Default::default(),
);
ElectraForMaskedLM {
electra,
generator_head,
lm_head,
}
}
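
    /// Forward pass through the generator.
    ///
    /// Arguments mirror [`ElectraModel::forward_t`]. Panics if the base model rejects
    /// the inputs (e.g. neither or both of `input_ids` and `input_embeds` provided).
    ///
    /// # Example
    ///
    /// A minimal sketch; the `rust_bert::electra` re-export path is an assumption:
    ///
    /// ```no_run
    /// # use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
    /// # use tch::{nn, no_grad, Device, Kind, Tensor};
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = ElectraConfig::default();
    /// # let model = ElectraForMaskedLM::new(vs.root(), &config);
    /// let input_ids = Tensor::zeros(&[2, 16], (Kind::Int64, device));
    /// let output = no_grad(|| model.forward_t(Some(&input_ids), None, None, None, None, false));
    /// // Scores of shape (batch size, sequence length, vocab_size)
    /// let scores = output.prediction_scores;
    /// ```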
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> ElectraMaskedLMOutput {
let base_model_output = self
.electra
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let hidden_states = self.generator_head.forward(&base_model_output.hidden_state);
let prediction_scores = hidden_states.apply(&self.lm_head);
ElectraMaskedLMOutput {
prediction_scores,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
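
/// # ELECTRA discriminator
/// Base model followed by the discriminator head; outputs a per-token probability that
/// the token was replaced by the generator.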
pub struct ElectraDiscriminator {
electra: ElectraModel,
discriminator_head: ElectraDiscriminatorHead,
}
impl ElectraDiscriminator {
pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraDiscriminator
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let electra = ElectraModel::new(p / "electra", config);
let discriminator_head =
ElectraDiscriminatorHead::new(p / "discriminator_predictions", config);
ElectraDiscriminator {
electra,
discriminator_head,
}
}
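
    /// Forward pass through the discriminator.
    ///
    /// Arguments mirror [`ElectraModel::forward_t`]. Panics if the base model rejects
    /// the inputs.
    ///
    /// # Example
    ///
    /// A minimal sketch; the `rust_bert::electra` re-export path is an assumption:
    ///
    /// ```no_run
    /// # use rust_bert::electra::{ElectraConfig, ElectraDiscriminator};
    /// # use tch::{nn, no_grad, Device, Kind, Tensor};
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = ElectraConfig::default();
    /// # let model = ElectraDiscriminator::new(vs.root(), &config);
    /// let input_ids = Tensor::zeros(&[2, 16], (Kind::Int64, device));
    /// let output = no_grad(|| model.forward_t(Some(&input_ids), None, None, None, None, false));
    /// // Probabilities of shape (batch size, sequence length)
    /// let probabilities = output.probabilities;
    /// ```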
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> ElectraDiscriminatorOutput {
let base_model_output = self
.electra
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let probabilities = self
.discriminator_head
.forward(&base_model_output.hidden_state)
.sigmoid();
ElectraDiscriminatorOutput {
probabilities,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
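
/// # ELECTRA for token classification
/// Base model with dropout and a linear classification layer on top, suitable for
/// tasks such as named entity recognition or part-of-speech tagging.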
pub struct ElectraForTokenClassification {
electra: ElectraModel,
dropout: Dropout,
classifier: nn::Linear,
}
impl ElectraForTokenClassification {
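    /// Builds a new `ElectraForTokenClassification`.
    ///
    /// Returns an error if `config.id2label` is not set, since the number of labels
    /// determines the width of the classification layer.
    ///
    /// # Example
    ///
    /// A minimal sketch with a hypothetical two-label mapping; the `rust_bert::electra`
    /// re-export path is an assumption:
    ///
    /// ```no_run
    /// # use rust_bert::electra::{ElectraConfig, ElectraForTokenClassification};
    /// # use std::collections::HashMap;
    /// # use tch::{nn, Device};
    /// # let vs = nn::VarStore::new(Device::Cpu);
    /// let mut config = ElectraConfig::default();
    /// config.id2label = Some(HashMap::from([(0, "O".to_string()), (1, "ENT".to_string())]));
    /// let model = ElectraForTokenClassification::new(vs.root(), &config).unwrap();
    /// ```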
pub fn new<'p, P>(
p: P,
config: &ElectraConfig,
) -> Result<ElectraForTokenClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let electra = ElectraModel::new(p / "electra", config);
let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = config
.id2label
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"id2label must be provided for classifiers".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
p / "classifier",
config.hidden_size,
num_labels,
Default::default(),
);
Ok(ElectraForTokenClassification {
electra,
dropout,
classifier,
})
}
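
    /// Forward pass through the token classifier.
    ///
    /// Arguments mirror [`ElectraModel::forward_t`]; returns per-token `logits` of
    /// shape (*batch size*, *sequence length*, *number of labels*). Panics if the
    /// base model rejects the inputs.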
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> ElectraTokenClassificationOutput {
let base_model_output = self
.electra
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let logits = base_model_output
.hidden_state
.apply_t(&self.dropout, train)
.apply(&self.classifier);
ElectraTokenClassificationOutput {
logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
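
/// Output of the ELECTRA base model: last hidden state of shape (*batch size*,
/// *sequence length*, *hidden_size*), plus optional hidden states and attention
/// weights for all layers.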
pub struct ElectraModelOutput {
pub hidden_state: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
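
/// Output of the discriminator: per-token replacement probabilities of shape
/// (*batch size*, *sequence length*).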
pub struct ElectraDiscriminatorOutput {
pub probabilities: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
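
/// Output of the masked language model: prediction scores of shape (*batch size*,
/// *sequence length*, *vocab_size*).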
pub struct ElectraMaskedLMOutput {
pub prediction_scores: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
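
/// Output of the token classifier: logits of shape (*batch size*, *sequence length*,
/// *number of labels*).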
pub struct ElectraTokenClassificationOutput {
pub logits: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}