use crate::bert::{BertConfig, BertModel};
use crate::common::activations::_gelu;
use crate::common::dropout::Dropout;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::roberta::embeddings::RobertaEmbeddings;
use crate::RustBertError;
use std::borrow::Borrow;
use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
use tch::{nn, Tensor};
/// Registry of pretrained RoBERTa model weight resources ((alias, URL) pairs).
pub struct RobertaModelResources;
/// Registry of pretrained RoBERTa configuration resources ((alias, URL) pairs).
pub struct RobertaConfigResources;
/// Registry of pretrained RoBERTa vocabulary resources ((alias, URL) pairs).
pub struct RobertaVocabResources;
/// Registry of pretrained RoBERTa BPE merges resources ((alias, URL) pairs).
pub struct RobertaMergesResources;
/// Pretrained weight files, each as a `(local cache alias, download URL)` tuple.
impl RobertaModelResources {
/// Base RoBERTa model weights.
pub const ROBERTA: (&'static str, &'static str) = (
"roberta/model",
"https://huggingface.co/roberta-base/resolve/main/rust_model.ot",
);
/// DistilRoBERTa base model weights.
pub const DISTILROBERTA_BASE: (&'static str, &'static str) = (
"distilroberta-base/model",
"https://huggingface.co/distilroberta-base/resolve/main/rust_model.ot",
);
/// RoBERTa base fine-tuned on SQuAD2 for question answering.
pub const ROBERTA_QA: (&'static str, &'static str) = (
"roberta-qa/model",
"https://huggingface.co/deepset/roberta-base-squad2/resolve/main/rust_model.ot",
);
/// XLM-RoBERTa large fine-tuned for English NER (CoNLL-03).
pub const XLM_ROBERTA_NER_EN: (&'static str, &'static str) = (
"xlm-roberta-ner-en/model",
"https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/rust_model.ot",
);
/// XLM-RoBERTa large fine-tuned for German NER (CoNLL-03).
pub const XLM_ROBERTA_NER_DE: (&'static str, &'static str) = (
"xlm-roberta-ner-de/model",
"https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/rust_model.ot",
);
/// XLM-RoBERTa large fine-tuned for Dutch NER (CoNLL-02).
pub const XLM_ROBERTA_NER_NL: (&'static str, &'static str) = (
"xlm-roberta-ner-nl/model",
"https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/rust_model.ot",
);
/// XLM-RoBERTa large fine-tuned for Spanish NER (CoNLL-02).
pub const XLM_ROBERTA_NER_ES: (&'static str, &'static str) = (
"xlm-roberta-ner-es/model",
"https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/rust_model.ot",
);
/// Sentence-transformers all-distilroberta-v1 embedding model weights.
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
"all-distilroberta-v1/model",
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/rust_model.ot",
);
/// CodeBERTa fine-tuned for programming-language identification.
pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = (
"codeberta-language-id/model",
"https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/rust_model.ot",
);
/// CodeBERT base masked-language-model weights.
pub const CODEBERT_MLM: (&'static str, &'static str) = (
"codebert-mlm/model",
"https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/rust_model.ot",
);
}
/// Pretrained configuration files, each as a `(local cache alias, download URL)` tuple.
impl RobertaConfigResources {
    /// Base RoBERTa model configuration.
    pub const ROBERTA: (&'static str, &'static str) = (
        "roberta/config",
        "https://huggingface.co/roberta-base/resolve/main/config.json",
    );
    /// DistilRoBERTa base configuration.
    // Fixed: previously pointed at the retired `cdn.huggingface.co` host; now uses the
    // same `huggingface.co/<repo>/resolve/main/` scheme as every other resource here.
    pub const DISTILROBERTA_BASE: (&'static str, &'static str) = (
        "distilroberta-base/config",
        "https://huggingface.co/distilroberta-base/resolve/main/config.json",
    );
    /// RoBERTa base SQuAD2 question-answering configuration.
    pub const ROBERTA_QA: (&'static str, &'static str) = (
        "roberta-qa/config",
        "https://huggingface.co/deepset/roberta-base-squad2/resolve/main/config.json",
    );
    /// XLM-RoBERTa large English NER (CoNLL-03) configuration.
    pub const XLM_ROBERTA_NER_EN: (&'static str, &'static str) = (
        "xlm-roberta-ner-en/config",
        "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/config.json",
    );
    /// XLM-RoBERTa large German NER (CoNLL-03) configuration.
    pub const XLM_ROBERTA_NER_DE: (&'static str, &'static str) = (
        "xlm-roberta-ner-de/config",
        "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/config.json",
    );
    /// XLM-RoBERTa large Dutch NER (CoNLL-02) configuration.
    pub const XLM_ROBERTA_NER_NL: (&'static str, &'static str) = (
        "xlm-roberta-ner-nl/config",
        "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/config.json",
    );
    /// XLM-RoBERTa large Spanish NER (CoNLL-02) configuration.
    pub const XLM_ROBERTA_NER_ES: (&'static str, &'static str) = (
        "xlm-roberta-ner-es/config",
        "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/config.json",
    );
    /// Sentence-transformers all-distilroberta-v1 configuration.
    pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
        "all-distilroberta-v1/config",
        "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/config.json",
    );
    /// CodeBERTa language-identification configuration.
    pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = (
        "codeberta-language-id/config",
        "https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/config.json",
    );
    /// CodeBERT masked-language-model configuration.
    pub const CODEBERT_MLM: (&'static str, &'static str) = (
        "codebert-mlm/config",
        "https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/config.json",
    );
}
/// Pretrained vocabulary files, each as a `(local cache alias, download URL)` tuple.
/// XLM-RoBERTa variants use a SentencePiece model; the rest use a BPE `vocab.json`.
impl RobertaVocabResources {
    /// Base RoBERTa vocabulary.
    pub const ROBERTA: (&'static str, &'static str) = (
        "roberta/vocab",
        "https://huggingface.co/roberta-base/resolve/main/vocab.json",
    );
    /// DistilRoBERTa base vocabulary.
    // Fixed: previously pointed at the retired `cdn.huggingface.co` host; now uses the
    // same `huggingface.co/<repo>/resolve/main/` scheme as every other resource here.
    pub const DISTILROBERTA_BASE: (&'static str, &'static str) = (
        "distilroberta-base/vocab",
        "https://huggingface.co/distilroberta-base/resolve/main/vocab.json",
    );
    /// RoBERTa base SQuAD2 question-answering vocabulary.
    pub const ROBERTA_QA: (&'static str, &'static str) = (
        "roberta-qa/vocab",
        "https://huggingface.co/deepset/roberta-base-squad2/resolve/main/vocab.json",
    );
    /// XLM-RoBERTa large English NER SentencePiece model.
    pub const XLM_ROBERTA_NER_EN: (&'static str, &'static str) = (
        "xlm-roberta-ner-en/spiece",
        "https://huggingface.co/xlm-roberta-large-finetuned-conll03-english/resolve/main/sentencepiece.bpe.model",
    );
    /// XLM-RoBERTa large German NER SentencePiece model.
    pub const XLM_ROBERTA_NER_DE: (&'static str, &'static str) = (
        "xlm-roberta-ner-de/spiece",
        "https://huggingface.co/xlm-roberta-large-finetuned-conll03-german/resolve/main/sentencepiece.bpe.model",
    );
    /// XLM-RoBERTa large Dutch NER SentencePiece model.
    pub const XLM_ROBERTA_NER_NL: (&'static str, &'static str) = (
        "xlm-roberta-ner-nl/spiece",
        "https://huggingface.co/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/sentencepiece.bpe.model",
    );
    /// XLM-RoBERTa large Spanish NER SentencePiece model.
    pub const XLM_ROBERTA_NER_ES: (&'static str, &'static str) = (
        "xlm-roberta-ner-es/spiece",
        "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/sentencepiece.bpe.model",
    );
    /// Sentence-transformers all-distilroberta-v1 vocabulary.
    pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
        "all-distilroberta-v1/vocab",
        "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/vocab.json",
    );
    /// CodeBERTa language-identification vocabulary.
    pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = (
        "codeberta-language-id/vocab",
        "https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/vocab.json",
    );
    /// CodeBERT masked-language-model vocabulary.
    pub const CODEBERT_MLM: (&'static str, &'static str) = (
        "codebert-mlm/vocab",
        "https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/vocab.json",
    );
}
/// Pretrained BPE merges files, each as a `(local cache alias, download URL)` tuple.
/// Only BPE-tokenized variants appear here (SentencePiece models need no merges file).
impl RobertaMergesResources {
    /// Base RoBERTa merges.
    pub const ROBERTA: (&'static str, &'static str) = (
        "roberta/merges",
        "https://huggingface.co/roberta-base/resolve/main/merges.txt",
    );
    /// DistilRoBERTa base merges.
    // Fixed: previously pointed at the retired `cdn.huggingface.co` host; now uses the
    // same `huggingface.co/<repo>/resolve/main/` scheme as every other resource here.
    pub const DISTILROBERTA_BASE: (&'static str, &'static str) = (
        "distilroberta-base/merges",
        "https://huggingface.co/distilroberta-base/resolve/main/merges.txt",
    );
    /// RoBERTa base SQuAD2 question-answering merges.
    pub const ROBERTA_QA: (&'static str, &'static str) = (
        "roberta-qa/merges",
        "https://huggingface.co/deepset/roberta-base-squad2/resolve/main/merges.txt",
    );
    /// Sentence-transformers all-distilroberta-v1 merges.
    pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
        "all-distilroberta-v1/merges",
        "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/merges.txt",
    );
    /// CodeBERTa language-identification merges.
    pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = (
        "codeberta-language-id/merges",
        "https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/merges.txt",
    );
    /// CodeBERT masked-language-model merges.
    pub const CODEBERT_MLM: (&'static str, &'static str) = (
        "codebert-mlm/merges",
        "https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/merges.txt",
    );
}
/// Language-model head mapping encoder hidden states to vocabulary logits
/// (dense -> GELU -> layer norm -> bias-free decoder + separate bias).
pub struct RobertaLMHead {
// Hidden-to-hidden transform applied before the layer norm.
dense: nn::Linear,
// Hidden-to-vocab projection; its bias is kept as the separate `bias` tensor.
decoder: LinearNoBias,
layer_norm: nn::LayerNorm,
// Output bias added after the decoder projection.
bias: Tensor,
}
impl RobertaLMHead {
    /// Builds the LM head under variable-store path `p`, sized from `config`
    /// (`hidden_size` for the transform, `vocab_size` for the output projection).
    pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaLMHead
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        let dense = nn::linear(
            path / "dense",
            config.hidden_size,
            config.hidden_size,
            Default::default(),
        );
        let layer_norm = nn::layer_norm(
            path / "layer_norm",
            vec![config.hidden_size],
            nn::LayerNormConfig {
                eps: 1e-12,
                ..Default::default()
            },
        );
        // The decoder carries no bias of its own; the bias lives in a separate
        // variable so it can be registered under the `bias` name.
        let decoder = linear_no_bias(
            path / "decoder",
            config.hidden_size,
            config.vocab_size,
            Default::default(),
        );
        let bias = path.var("bias", &[config.vocab_size], DEFAULT_KAIMING_UNIFORM);
        RobertaLMHead {
            dense,
            decoder,
            layer_norm,
            bias,
        }
    }
    /// Projects `hidden_states` to vocabulary logits:
    /// dense -> GELU -> layer norm -> decoder, then adds the standalone bias.
    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
        let activated = _gelu(&hidden_states.apply(&self.dense));
        activated.apply(&self.layer_norm).apply(&self.decoder) + &self.bias
    }
}
/// RoBERTa reuses the BERT configuration structure unchanged.
pub type RobertaConfig = BertConfig;
/// RoBERTa model with a masked-language-modeling head on top.
pub struct RobertaForMaskedLM {
// Encoder built without a pooler (the LM head only needs token-level states).
roberta: BertModel<RobertaEmbeddings>,
lm_head: RobertaLMHead,
}
impl RobertaForMaskedLM {
    /// Creates a masked-LM model: a pooler-less RoBERTa encoder plus an LM head,
    /// registered under `p / "roberta"` and `p / "lm_head"` respectively.
    pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForMaskedLM
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        RobertaForMaskedLM {
            roberta: BertModel::<RobertaEmbeddings>::new_with_optional_pooler(
                path / "roberta",
                config,
                false,
            ),
            lm_head: RobertaLMHead::new(path / "lm_head", config),
        }
    }
    /// Runs the encoder then the LM head, producing per-token vocabulary logits.
    ///
    /// NOTE(review): the encoder result is unwrapped, so an encoder error (presumably
    /// an invalid input combination — confirm in `BertModel::forward_t`) panics here.
    #[allow(rustdoc::invalid_html_tags)]
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        encoder_hidden_states: Option<&Tensor>,
        encoder_mask: Option<&Tensor>,
        train: bool,
    ) -> RobertaMaskedLMOutput {
        let encoder_output = self
            .roberta
            .forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                encoder_hidden_states,
                encoder_mask,
                train,
            )
            .unwrap();
        RobertaMaskedLMOutput {
            prediction_scores: self.lm_head.forward(&encoder_output.hidden_state),
            all_hidden_states: encoder_output.all_hidden_states,
            all_attentions: encoder_output.all_attentions,
        }
    }
}
/// Classification head operating on the first-token representation:
/// dropout -> dense -> tanh -> dropout -> output projection.
pub struct RobertaClassificationHead {
dense: nn::Linear,
dropout: Dropout,
// Projects to `num_labels` (taken from the config's `id2label` map).
out_proj: nn::Linear,
}
impl RobertaClassificationHead {
    /// Builds the head under path `p`; the output width is the size of the
    /// config's `id2label` map.
    ///
    /// # Errors
    /// Returns `RustBertError::InvalidConfigurationError` when `id2label` is absent.
    pub fn new<'p, P>(p: P, config: &BertConfig) -> Result<RobertaClassificationHead, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        let num_labels = match config.id2label.as_ref() {
            Some(label_map) => label_map.len() as i64,
            None => {
                return Err(RustBertError::InvalidConfigurationError(
                    "num_labels not provided in configuration".to_string(),
                ));
            }
        };
        let dense = nn::linear(
            path / "dense",
            config.hidden_size,
            config.hidden_size,
            Default::default(),
        );
        let out_proj = nn::linear(
            path / "out_proj",
            config.hidden_size,
            num_labels,
            Default::default(),
        );
        Ok(RobertaClassificationHead {
            dense,
            dropout: Dropout::new(config.hidden_dropout_prob),
            out_proj,
        })
    }
    /// Maps token-level hidden states to classification logits using the
    /// representation of the first token (index 0 along the sequence axis).
    pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
        let first_token = hidden_states.select(1, 0);
        let transformed = first_token.apply_t(&self.dropout, train).apply(&self.dense);
        transformed
            .tanh()
            .apply_t(&self.dropout, train)
            .apply(&self.out_proj)
    }
}
/// RoBERTa model with a sequence-classification head on top.
pub struct RobertaForSequenceClassification {
// Encoder built without a pooler; the head reads the first token's state instead.
roberta: BertModel<RobertaEmbeddings>,
classifier: RobertaClassificationHead,
}
impl RobertaForSequenceClassification {
    /// Creates a sequence classifier: pooler-less encoder plus classification head.
    ///
    /// # Errors
    /// Returns `RustBertError::InvalidConfigurationError` when the config lacks
    /// an `id2label` map (number of labels unknown).
    pub fn new<'p, P>(
        p: P,
        config: &BertConfig,
    ) -> Result<RobertaForSequenceClassification, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        let roberta = BertModel::<RobertaEmbeddings>::new_with_optional_pooler(
            path / "roberta",
            config,
            false,
        );
        let classifier = RobertaClassificationHead::new(path / "classifier", config)?;
        Ok(RobertaForSequenceClassification {
            roberta,
            classifier,
        })
    }
    /// Runs the encoder and classification head, returning one logit vector per sequence.
    ///
    /// NOTE(review): the encoder result is unwrapped, so an encoder error panics here.
    #[allow(rustdoc::invalid_html_tags)]
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> RobertaSequenceClassificationOutput {
        let encoder_output = self
            .roberta
            .forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                None,
                None,
                train,
            )
            .unwrap();
        RobertaSequenceClassificationOutput {
            logits: self
                .classifier
                .forward_t(&encoder_output.hidden_state, train),
            all_hidden_states: encoder_output.all_hidden_states,
            all_attentions: encoder_output.all_attentions,
        }
    }
}
/// RoBERTa model with a multiple-choice head: scores each candidate answer and
/// returns one logit per choice.
#[allow(rustdoc::invalid_html_tags)]
pub struct RobertaForMultipleChoice {
// Full encoder including the pooler; the pooled output feeds the classifier.
roberta: BertModel<RobertaEmbeddings>,
dropout: Dropout,
// Projects the pooled output to a single score per (example, choice) pair.
classifier: nn::Linear,
}
impl RobertaForMultipleChoice {
    /// Creates a multiple-choice model. Unlike the other task heads, this uses the
    /// full encoder (with pooler) because scoring relies on the pooled output.
    pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForMultipleChoice
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        RobertaForMultipleChoice {
            roberta: BertModel::<RobertaEmbeddings>::new(path / "roberta", config),
            dropout: Dropout::new(config.hidden_dropout_prob),
            classifier: nn::linear(path / "classifier", config.hidden_size, 1, Default::default()),
        }
    }
    /// Scores each choice. `input_ids` is expected with a choices dimension at
    /// index 1 (assumed layout: batch x choices x sequence — TODO confirm against callers);
    /// inputs are flattened to (batch * choices, sequence) for the encoder, and the
    /// per-choice scores are reshaped back to (batch, choices).
    ///
    /// NOTE(review): both the encoder result and `pooled_output` are unwrapped and
    /// will panic on encoder error or a missing pooler.
    #[allow(rustdoc::invalid_html_tags)]
    pub fn forward_t(
        &self,
        input_ids: &Tensor,
        mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        train: bool,
    ) -> RobertaSequenceClassificationOutput {
        let num_choices = input_ids.size()[1];
        // Collapse all leading dimensions into one, keeping the sequence length.
        let flatten = |t: &Tensor| t.view((-1, *t.size().last().unwrap()));
        let flat_input_ids = Some(flatten(input_ids));
        let flat_mask = mask.map(flatten);
        let flat_token_type_ids = token_type_ids.map(flatten);
        let flat_position_ids = position_ids.map(flatten);
        let encoder_output = self
            .roberta
            .forward_t(
                flat_input_ids.as_ref(),
                flat_mask.as_ref(),
                flat_token_type_ids.as_ref(),
                flat_position_ids.as_ref(),
                None,
                None,
                None,
                train,
            )
            .unwrap();
        let logits = encoder_output
            .pooled_output
            .unwrap()
            .apply_t(&self.dropout, train)
            .apply(&self.classifier)
            .view((-1, num_choices));
        RobertaSequenceClassificationOutput {
            logits,
            all_hidden_states: encoder_output.all_hidden_states,
            all_attentions: encoder_output.all_attentions,
        }
    }
}
/// RoBERTa model with a token-classification head (e.g. for NER or POS tagging).
pub struct RobertaForTokenClassification {
// Encoder built without a pooler; classification is per token.
roberta: BertModel<RobertaEmbeddings>,
dropout: Dropout,
// Projects each token state to `num_labels` (from the config's `id2label` map).
classifier: nn::Linear,
}
impl RobertaForTokenClassification {
    /// Creates a token classifier: pooler-less encoder, dropout, and a linear
    /// layer sized by the config's `id2label` map.
    ///
    /// # Errors
    /// Returns `RustBertError::InvalidConfigurationError` when `id2label` is absent.
    pub fn new<'p, P>(
        p: P,
        config: &BertConfig,
    ) -> Result<RobertaForTokenClassification, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        let roberta = BertModel::<RobertaEmbeddings>::new_with_optional_pooler(
            path / "roberta",
            config,
            false,
        );
        let num_labels = match config.id2label.as_ref() {
            Some(label_map) => label_map.len() as i64,
            None => {
                return Err(RustBertError::InvalidConfigurationError(
                    "num_labels not provided in configuration".to_string(),
                ));
            }
        };
        let classifier = nn::linear(
            path / "classifier",
            config.hidden_size,
            num_labels,
            Default::default(),
        );
        Ok(RobertaForTokenClassification {
            roberta,
            dropout: Dropout::new(config.hidden_dropout_prob),
            classifier,
        })
    }
    /// Runs the encoder and classifies every token, returning per-token label logits.
    ///
    /// NOTE(review): the encoder result is unwrapped, so an encoder error panics here.
    #[allow(rustdoc::invalid_html_tags)]
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> RobertaTokenClassificationOutput {
        let encoder_output = self
            .roberta
            .forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                None,
                None,
                train,
            )
            .unwrap();
        RobertaTokenClassificationOutput {
            logits: encoder_output
                .hidden_state
                .apply_t(&self.dropout, train)
                .apply(&self.classifier),
            all_hidden_states: encoder_output.all_hidden_states,
            all_attentions: encoder_output.all_attentions,
        }
    }
}
/// RoBERTa model with an extractive question-answering head predicting answer
/// span start and end positions.
pub struct RobertaForQuestionAnswering {
// Encoder built without a pooler; span logits are per token.
roberta: BertModel<RobertaEmbeddings>,
// Projects each token state to 2 logits: span start and span end.
qa_outputs: nn::Linear,
}
impl RobertaForQuestionAnswering {
    /// Creates a span-extraction QA model: pooler-less encoder plus a linear layer
    /// producing two logits per token (answer start and answer end).
    pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForQuestionAnswering
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        let roberta = BertModel::<RobertaEmbeddings>::new_with_optional_pooler(
            path / "roberta",
            config,
            false,
        );
        // Two outputs per token: start logit and end logit.
        let qa_outputs = nn::linear(path / "qa_outputs", config.hidden_size, 2, Default::default());
        RobertaForQuestionAnswering {
            roberta,
            qa_outputs,
        }
    }
    /// Runs the encoder and QA head, returning per-token start and end logits.
    ///
    /// NOTE(review): the encoder result is unwrapped, so an encoder error panics here.
    #[allow(rustdoc::invalid_html_tags)]
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> RobertaQuestionAnsweringOutput {
        let encoder_output = self
            .roberta
            .forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                None,
                None,
                train,
            )
            .unwrap();
        // Split the 2-wide last dimension into separate start/end logits and drop it.
        let span_logits = encoder_output.hidden_state.apply(&self.qa_outputs);
        let split = span_logits.split(1, -1);
        let start_logits = split[0].squeeze_dim(-1);
        let end_logits = split[1].squeeze_dim(-1);
        RobertaQuestionAnsweringOutput {
            start_logits,
            end_logits,
            all_hidden_states: encoder_output.all_hidden_states,
            all_attentions: encoder_output.all_attentions,
        }
    }
}
/// Sentence-embedding use reuses the bare encoder with no task-specific head.
pub type RobertaForSentenceEmbeddings = BertModel<RobertaEmbeddings>;
/// Output of `RobertaForMaskedLM::forward_t`.
pub struct RobertaMaskedLMOutput {
// Vocabulary logits for every token position.
pub prediction_scores: Tensor,
// Hidden states of all layers, when the encoder produced them.
pub all_hidden_states: Option<Vec<Tensor>>,
// Attention weights of all layers, when the encoder produced them.
pub all_attentions: Option<Vec<Tensor>>,
}
/// Output of the sequence-classification and multiple-choice forward passes.
pub struct RobertaSequenceClassificationOutput {
// Classification logits (one score per label, or per choice for multiple choice).
pub logits: Tensor,
// Hidden states of all layers, when the encoder produced them.
pub all_hidden_states: Option<Vec<Tensor>>,
// Attention weights of all layers, when the encoder produced them.
pub all_attentions: Option<Vec<Tensor>>,
}
/// Output of `RobertaForTokenClassification::forward_t`.
pub struct RobertaTokenClassificationOutput {
// Label logits for every token position.
pub logits: Tensor,
// Hidden states of all layers, when the encoder produced them.
pub all_hidden_states: Option<Vec<Tensor>>,
// Attention weights of all layers, when the encoder produced them.
pub all_attentions: Option<Vec<Tensor>>,
}
/// Output of `RobertaForQuestionAnswering::forward_t`.
pub struct RobertaQuestionAnsweringOutput {
// Per-token logits for the answer span start position.
pub start_logits: Tensor,
// Per-token logits for the answer span end position.
pub end_logits: Tensor,
// Hidden states of all layers, when the encoder produced them.
pub all_hidden_states: Option<Vec<Tensor>>,
// Attention weights of all layers, when the encoder produced them.
pub all_attentions: Option<Vec<Tensor>>,
}