use tch::{nn, Tensor};
use crate::common::linear::{linear_no_bias, LinearNoBias};
use tch::nn::Init;
use crate::common::activations::_gelu;
use crate::roberta::embeddings::RobertaEmbeddings;
use crate::common::dropout::Dropout;
use crate::bert::{BertConfig, BertModel};
/// Registry of pretrained RoBERTa model weight resources (local alias, remote URL).
pub struct RobertaModelResources;
/// Registry of pretrained RoBERTa configuration resources (local alias, remote URL).
pub struct RobertaConfigResources;
/// Registry of pretrained RoBERTa vocabulary resources (local alias, remote URL).
pub struct RobertaVocabResources;
/// Registry of pretrained RoBERTa BPE merges resources (local alias, remote URL).
pub struct RobertaMergesResources;
impl RobertaModelResources {
/// Base RoBERTa weights converted to the `tch` `.ot` format.
pub const ROBERTA: (&'static str, &'static str) = ("roberta/model.ot", "https://cdn.huggingface.co/roberta-base-rust_model.ot");
}
impl RobertaConfigResources {
/// Base RoBERTa model configuration (JSON).
pub const ROBERTA: (&'static str, &'static str) = ("roberta/config.json", "https://cdn.huggingface.co/roberta-base-config.json");
}
impl RobertaVocabResources {
// NOTE(review): local alias says vocab.txt but the remote file is vocab.json
// (RoBERTa uses a JSON vocab) — presumably intentional aliasing; confirm against
// the resource-download code before renaming.
pub const ROBERTA: (&'static str, &'static str) = ("roberta/vocab.txt", "https://cdn.huggingface.co/roberta-base-vocab.json");
}
impl RobertaMergesResources {
/// BPE merges file used by the RoBERTa byte-level tokenizer.
pub const ROBERTA: (&'static str, &'static str) = ("roberta/merges.txt", "https://cdn.huggingface.co/roberta-base-merges.txt");
}
/// Masked-language-model head for RoBERTa: a dense transform followed by GELU and
/// layer normalization, then a bias-free projection to the vocabulary with a
/// separately stored bias tensor (matching the upstream checkpoint layout).
pub struct RobertaLMHead {
// Hidden-to-hidden transform applied before the vocabulary projection.
dense: nn::Linear,
// Vocabulary projection without its own bias (bias is stored separately below).
decoder: LinearNoBias,
// Layer normalization applied after the GELU activation.
layer_norm: nn::LayerNorm,
// Output bias added to the decoder projection, shape [vocab_size].
bias: Tensor,
}
impl RobertaLMHead {
    /// Builds the LM head under variable-store path `p`, sized from `config`
    /// (hidden_size for the dense/layer-norm stages, vocab_size for the output).
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaLMHead {
        let hidden = config.hidden_size;
        let dense = nn::linear(p / "dense", hidden, hidden, Default::default());
        let layer_norm = nn::layer_norm(
            p / "layer_norm",
            vec![hidden],
            nn::LayerNormConfig { eps: 1e-12, ..Default::default() },
        );
        let decoder = linear_no_bias(&(p / "decoder"), hidden, config.vocab_size, Default::default());
        // The output bias lives outside the decoder so it matches the checkpoint layout.
        let bias = p.var("bias", &[config.vocab_size], Init::KaimingUniform);
        RobertaLMHead { dense, decoder, layer_norm, bias }
    }

    /// Maps hidden states to vocabulary logits:
    /// dense -> GELU -> layer norm -> decoder projection + bias.
    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
        let activated = _gelu(&hidden_states.apply(&self.dense));
        activated.apply(&self.layer_norm).apply(&self.decoder) + &self.bias
    }
}
/// RoBERTa model with a masked-language-modeling head on top.
pub struct RobertaForMaskedLM {
// Base transformer using RoBERTa-specific embeddings.
roberta: BertModel<RobertaEmbeddings>,
// Head producing per-token vocabulary logits.
lm_head: RobertaLMHead,
}
impl RobertaForMaskedLM {
    /// Creates the masked-LM model: base transformer under `p / "roberta"` and
    /// the LM head under `p / "lm_head"`.
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMaskedLM {
        RobertaForMaskedLM {
            roberta: BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config),
            lm_head: RobertaLMHead::new(&(p / "lm_head"), config),
        }
    }

    /// Forward pass. Returns (vocabulary logits, optional all hidden states,
    /// optional all attention weights). Panics if the base forward pass fails.
    pub fn forward_t(&self,
                     input_ids: Option<Tensor>,
                     mask: Option<Tensor>,
                     token_type_ids: Option<Tensor>,
                     position_ids: Option<Tensor>,
                     input_embeds: Option<Tensor>,
                     encoder_hidden_states: &Option<Tensor>,
                     encoder_mask: &Option<Tensor>,
                     train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        let base_output = self
            .roberta
            .forward_t(input_ids, mask, token_type_ids, position_ids,
                       input_embeds, encoder_hidden_states, encoder_mask, train)
            .unwrap();
        // Pooled output (second element) is unused for masked LM.
        let (hidden_state, _pooled, all_hidden_states, all_attentions) = base_output;
        let prediction_scores = self.lm_head.forward(&hidden_state);
        (prediction_scores, all_hidden_states, all_attentions)
    }
}
/// Classification head for RoBERTa sequence classification: first-token pooling,
/// dense + tanh, then projection to the label space, with dropout around the
/// dense layer.
pub struct RobertaClassificationHead {
// Hidden-to-hidden transform applied to the pooled (first-token) representation.
dense: nn::Linear,
// Dropout applied before the dense layer and before the output projection.
dropout: Dropout,
// Projection from hidden_size to the number of labels.
out_proj: nn::Linear,
}
impl RobertaClassificationHead {
    /// Builds the head under `p`; the number of labels is taken from
    /// `config.id2label` and must be present.
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaClassificationHead {
        let num_labels = config
            .id2label
            .as_ref()
            .expect("num_labels not provided in configuration")
            .len() as i64;
        RobertaClassificationHead {
            dense: nn::linear(p / "dense", config.hidden_size, config.hidden_size, Default::default()),
            dropout: Dropout::new(config.hidden_dropout_prob),
            out_proj: nn::linear(p / "out_proj", config.hidden_size, num_labels, Default::default()),
        }
    }

    /// Pools the first token of `hidden_states`, then applies
    /// dropout -> dense -> tanh -> dropout -> output projection.
    pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
        // Select the first token along the sequence dimension as the pooled representation.
        let pooled = hidden_states.select(1, 0).apply_t(&self.dropout, train);
        let transformed = pooled.apply(&self.dense).tanh().apply_t(&self.dropout, train);
        transformed.apply(&self.out_proj)
    }
}
/// RoBERTa model with a sequence-classification head on top.
pub struct RobertaForSequenceClassification {
// Base transformer using RoBERTa-specific embeddings.
roberta: BertModel<RobertaEmbeddings>,
// Head mapping the pooled first-token state to per-label logits.
classifier: RobertaClassificationHead,
}
impl RobertaForSequenceClassification {
    /// Creates the model: base transformer under `p / "roberta"`, classification
    /// head under `p / "classifier"`.
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForSequenceClassification {
        RobertaForSequenceClassification {
            roberta: BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config),
            classifier: RobertaClassificationHead::new(&(p / "classifier"), config),
        }
    }

    /// Forward pass. Returns (label logits, optional all hidden states,
    /// optional all attention weights). Panics if the base forward pass fails.
    pub fn forward_t(&self,
                     input_ids: Option<Tensor>,
                     mask: Option<Tensor>,
                     token_type_ids: Option<Tensor>,
                     position_ids: Option<Tensor>,
                     input_embeds: Option<Tensor>,
                     train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        let base_output = self
            .roberta
            .forward_t(input_ids, mask, token_type_ids, position_ids,
                       input_embeds, &None, &None, train)
            .unwrap();
        // Pooled output (second element) is unused; the head does its own pooling.
        let (hidden_state, _pooled, all_hidden_states, all_attentions) = base_output;
        let output = self.classifier.forward_t(&hidden_state, train);
        (output, all_hidden_states, all_attentions)
    }
}
/// RoBERTa model with a multiple-choice head: one score per choice, compared
/// across choices after reshaping.
pub struct RobertaForMultipleChoice {
// Base transformer using RoBERTa-specific embeddings.
roberta: BertModel<RobertaEmbeddings>,
// Dropout applied to the pooled output before scoring.
dropout: Dropout,
// Linear layer producing a single score per (flattened) choice.
classifier: nn::Linear,
}
impl RobertaForMultipleChoice {
    /// Creates the model: base transformer under `p / "roberta"`, dropout, and a
    /// single-output linear classifier under `p / "classifier"`.
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMultipleChoice {
        let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
        let dropout = Dropout::new(config.hidden_dropout_prob);
        // One score per choice; scores are regrouped to (batch, num_choices) below.
        let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
        RobertaForMultipleChoice { roberta, dropout, classifier }
    }

    /// Forward pass. `input_ids` is expected to be rank-3 with choices on
    /// dimension 1 (assumed (batch, num_choices, seq_len) — TODO confirm against
    /// callers); choices are flattened into the batch dimension for the base
    /// forward pass, and the per-choice scores are reshaped back to
    /// (batch, num_choices). Returns (choice scores, optional all hidden states,
    /// optional all attention weights). Panics if the base forward pass fails.
    pub fn forward_t(&self,
                     input_ids: Tensor,
                     mask: Option<Tensor>,
                     token_type_ids: Option<Tensor>,
                     position_ids: Option<Tensor>,
                     train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        let num_choices = input_ids.size()[1];
        // Collapse the choice dimension: (batch, num_choices, seq_len) -> (batch * num_choices, seq_len).
        // `Option::map` replaces the repeated `match Some/None` boilerplate.
        let flat_input_ids = Some(input_ids.view((-1i64, *input_ids.size().last().unwrap())));
        let flat_position_ids =
            position_ids.map(|value| value.view((-1i64, *value.size().last().unwrap())));
        let flat_token_type_ids =
            token_type_ids.map(|value| value.view((-1i64, *value.size().last().unwrap())));
        let flat_mask = mask.map(|value| value.view((-1i64, *value.size().last().unwrap())));
        let (_, pooled_output, all_hidden_states, all_attentions) = self
            .roberta
            .forward_t(flat_input_ids, flat_mask, flat_token_type_ids, flat_position_ids,
                       None, &None, &None, train)
            .unwrap();
        // Score each flattened choice, then regroup scores per original batch item.
        let output = pooled_output
            .apply_t(&self.dropout, train)
            .apply(&self.classifier)
            .view((-1, num_choices));
        (output, all_hidden_states, all_attentions)
    }
}
/// RoBERTa model with a token-classification head (per-token label logits).
pub struct RobertaForTokenClassification {
// Base transformer using RoBERTa-specific embeddings.
roberta: BertModel<RobertaEmbeddings>,
// Dropout applied to the sequence output before classification.
dropout: Dropout,
// Linear layer mapping each token's hidden state to label logits.
classifier: nn::Linear,
}
impl RobertaForTokenClassification {
    /// Creates the model under `p`; the number of labels is taken from
    /// `config.id2label` and must be present.
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForTokenClassification {
        let num_labels = config
            .id2label
            .as_ref()
            .expect("num_labels not provided in configuration")
            .len() as i64;
        RobertaForTokenClassification {
            roberta: BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config),
            dropout: Dropout::new(config.hidden_dropout_prob),
            classifier: nn::linear(p / "classifier", config.hidden_size, num_labels, Default::default()),
        }
    }

    /// Forward pass. Returns (per-token label logits, optional all hidden states,
    /// optional all attention weights). Panics if the base forward pass fails.
    pub fn forward_t(&self,
                     input_ids: Option<Tensor>,
                     mask: Option<Tensor>,
                     token_type_ids: Option<Tensor>,
                     position_ids: Option<Tensor>,
                     input_embeds: Option<Tensor>,
                     train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        let base_output = self
            .roberta
            .forward_t(input_ids, mask, token_type_ids, position_ids,
                       input_embeds, &None, &None, train)
            .unwrap();
        let (hidden_state, _pooled, all_hidden_states, all_attentions) = base_output;
        let sequence_output = hidden_state
            .apply_t(&self.dropout, train)
            .apply(&self.classifier);
        (sequence_output, all_hidden_states, all_attentions)
    }
}
/// RoBERTa model with an extractive question-answering head producing start and
/// end position logits.
pub struct RobertaForQuestionAnswering {
// Base transformer using RoBERTa-specific embeddings.
roberta: BertModel<RobertaEmbeddings>,
// Linear layer producing two logits per token (start, end).
qa_outputs: nn::Linear,
}
impl RobertaForQuestionAnswering {
    /// Creates the model: base transformer under `p / "roberta"` and a
    /// two-output (start/end) linear layer under `p / "qa_outputs"`.
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForQuestionAnswering {
        let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
        // Two outputs per token: start-position and end-position logits.
        let num_labels = 2;
        let qa_outputs = nn::linear(p / "qa_outputs", config.hidden_size, num_labels, Default::default());
        RobertaForQuestionAnswering { roberta, qa_outputs }
    }

    /// Forward pass. Returns (start logits, end logits, optional all hidden
    /// states, optional all attention weights). Panics if the base forward
    /// pass fails.
    pub fn forward_t(&self,
                     input_ids: Option<Tensor>,
                     mask: Option<Tensor>,
                     token_type_ids: Option<Tensor>,
                     position_ids: Option<Tensor>,
                     input_embeds: Option<Tensor>,
                     train: bool) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        let (hidden_state, _pooled, all_hidden_states, all_attentions) = self
            .roberta
            .forward_t(input_ids, mask, token_type_ids, position_ids,
                       input_embeds, &None, &None, train)
            .unwrap();
        // Split the two per-token logits into start and end, dropping the size-1 last dim.
        let logits = hidden_state.apply(&self.qa_outputs).split(1, -1);
        let start_logits = logits[0].squeeze1(-1);
        let end_logits = logits[1].squeeze1(-1);
        (start_logits, end_logits, all_hidden_states, all_attentions)
    }
}