use crate::bert::encoder::{BertEncoder, BertPooler};
use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::{
bert::embeddings::{BertEmbedding, BertEmbeddings},
common::activations::TensorFunction,
};
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
use tch::{nn, Kind, Tensor};
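/// Pretrained BERT model weights: each constant is a `(local alias, remote URL)` pair
/// pointing to a model file hosted on the Hugging Face model hub.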
pub struct BertModelResources;
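/// Pretrained BERT model configuration files, as `(local alias, remote URL)` pairs.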
pub struct BertConfigResources;
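/// Pretrained BERT vocabulary files, as `(local alias, remote URL)` pairs.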
pub struct BertVocabResources;
impl BertModelResources {
pub const BERT: (&'static str, &'static str) = (
"bert/model",
"https://huggingface.co/bert-base-uncased/resolve/main/rust_model.ot",
);
pub const BERT_LARGE: (&'static str, &'static str) = (
"bert-large/model",
"https://huggingface.co/bert-large-uncased/resolve/main/rust_model.ot",
);
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/model",
"https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/rust_model.ot",
);
pub const BERT_QA: (&'static str, &'static str) = (
"bert-qa/model",
"https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/rust_model.ot",
);
pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
"bert-base-nli-mean-tokens/model",
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/rust_model.ot",
);
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
"all-mini-lm-l12-v2/model",
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/rust_model.ot",
);
pub const ALL_MINI_LM_L6_V2: (&'static str, &'static str) = (
"all-mini-lm-l6-v2/model",
"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/rust_model.ot",
);
}
impl BertConfigResources {
pub const BERT: (&'static str, &'static str) = (
"bert/config",
"https://huggingface.co/bert-base-uncased/resolve/main/config.json",
);
pub const BERT_LARGE: (&'static str, &'static str) = (
"bert-large/config",
"https://huggingface.co/bert-large-uncased/resolve/main/config.json",
);
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/config",
"https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/config.json",
);
pub const BERT_QA: (&'static str, &'static str) = (
"bert-qa/config",
"https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json",
);
pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
"bert-base-nli-mean-tokens/config",
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/config.json",
);
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
"all-mini-lm-l12-v2/config",
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/config.json",
);
pub const ALL_MINI_LM_L6_V2: (&'static str, &'static str) = (
"all-mini-lm-l6-v2/config",
"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/config.json",
);
}
impl BertVocabResources {
pub const BERT: (&'static str, &'static str) = (
"bert/vocab",
"https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
);
pub const BERT_LARGE: (&'static str, &'static str) = (
"bert-large/vocab",
"https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
);
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/vocab",
"https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/resolve/main/vocab.txt",
);
pub const BERT_QA: (&'static str, &'static str) = (
"bert-qa/vocab",
"https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
);
pub const BERT_BASE_NLI_MEAN_TOKENS: (&'static str, &'static str) = (
"bert-base-nli-mean-tokens/vocab",
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/vocab.txt",
);
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
"all-mini-lm-l12-v2/vocab",
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/vocab.txt",
);
pub const ALL_MINI_LM_L6_V2: (&'static str, &'static str) = (
"all-mini-lm-l6-v2/vocab",
"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/vocab.txt",
);
}
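/// BERT model configuration, mirroring the fields of the Hugging Face `config.json`.
/// Defines the model architecture (hidden sizes, number of layers and attention heads,
/// vocabulary sizes) together with dropout probabilities and optional task metadata
/// (`id2label` / `label2id` for classification heads).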
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct BertConfig {
pub hidden_act: Activation,
pub attention_probs_dropout_prob: f64,
pub hidden_dropout_prob: f64,
pub hidden_size: i64,
pub initializer_range: f32,
pub intermediate_size: i64,
pub max_position_embeddings: i64,
pub num_attention_heads: i64,
pub num_hidden_layers: i64,
pub type_vocab_size: i64,
pub vocab_size: i64,
pub output_attentions: Option<bool>,
pub output_hidden_states: Option<bool>,
pub is_decoder: Option<bool>,
pub id2label: Option<HashMap<i64, String>>,
pub label2id: Option<HashMap<String, i64>>,
}
impl Config for BertConfig {}
impl Default for BertConfig {
fn default() -> Self {
BertConfig {
hidden_act: Activation::gelu,
attention_probs_dropout_prob: 0.1,
hidden_dropout_prob: 0.1,
hidden_size: 768,
initializer_range: 0.02,
intermediate_size: 3072,
max_position_embeddings: 512,
num_attention_heads: 12,
num_hidden_layers: 12,
type_vocab_size: 2,
vocab_size: 30522,
output_attentions: None,
output_hidden_states: None,
is_decoder: None,
id2label: None,
label2id: None,
}
}
}
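/// Base BERT model, generic over the embedding type `T` so that BERT variants can
/// reuse the same encoder with different embedding implementations. Composed of
/// token embeddings, a stack of transformer encoder layers, and an optional pooler
/// over the first (`[CLS]`) token. When `is_decoder` is set in the configuration, a
/// causal mask is applied and cross-attention over encoder hidden states is enabled.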
pub struct BertModel<T: BertEmbedding> {
embeddings: T,
encoder: BertEncoder,
pooler: Option<BertPooler>,
is_decoder: bool,
}
impl<T: BertEmbedding> BertModel<T> {
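    /// Builds a new `BertModel`, including a pooling layer. Variables are created
    /// under the `embeddings`, `encoder` and `pooler` sub-paths of `p`.
    ///
    /// A minimal construction sketch (the config path is a placeholder):
    ///
    /// ```no_run
    /// use rust_bert::bert::{BertConfig, BertEmbeddings, BertModel};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
    /// let vs = nn::VarStore::new(device);
    /// let config = BertConfig::from_file(config_path);
    /// let bert: BertModel<BertEmbeddings> = BertModel::new(&vs.root() / "bert", &config);
    /// ```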
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertModel<T>
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let is_decoder = config.is_decoder.unwrap_or(false);
let embeddings = T::new(p / "embeddings", config);
let encoder = BertEncoder::new(p / "encoder", config);
let pooler = Some(BertPooler::new(p / "pooler", config));
BertModel {
embeddings,
encoder,
pooler,
is_decoder,
}
}
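    /// Same as [`BertModel::new`], but the pooling layer is only created when
    /// `add_pooling_layer` is `true`. Heads that operate on the full sequence
    /// (masked LM, token classification) skip the pooler to avoid allocating
    /// unused weights.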
pub fn new_with_optional_pooler<'p, P>(
p: P,
config: &BertConfig,
add_pooling_layer: bool,
) -> BertModel<T>
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let is_decoder = config.is_decoder.unwrap_or(false);
let embeddings = T::new(p / "embeddings", config);
let encoder = BertEncoder::new(p / "encoder", config);
        let pooler = if add_pooling_layer {
            Some(BertPooler::new(p / "pooler", config))
        } else {
            None
        };
BertModel {
embeddings,
encoder,
pooler,
is_decoder,
}
}
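    /// Forward pass through the model.
    ///
    /// # Arguments
    ///
    /// * `input_ids` - Optional token ids of shape (batch size, sequence length). Must be provided when no `input_embeds` are given.
    /// * `mask` - Optional attention mask of shape (batch size, sequence length), 0 for positions to mask. Defaults to all ones.
    /// * `token_type_ids` - Optional segment ids of shape (batch size, sequence length). Defaults to all zeros.
    /// * `position_ids` - Optional position ids of shape (batch size, sequence length). Defaults to 0..sequence length.
    /// * `input_embeds` - Optional pre-computed embeddings, used instead of `input_ids`.
    /// * `encoder_hidden_states` - Optional encoder states for cross-attention (decoder configurations only).
    /// * `encoder_mask` - Optional attention mask over the encoder states.
    /// * `train` - If `true`, dropout is active.
    ///
    /// # Returns
    ///
    /// A `BertModelOutput` with the last hidden state, the pooled output (if a pooler
    /// is present), and, when enabled in the configuration, all hidden states and
    /// attention weights. Returns an error if neither `input_ids` nor `input_embeds`
    /// is provided, or if an attention mask is neither 2- nor 3-dimensional.
    ///
    /// A minimal usage sketch (`bert_model` built as in [`BertModel::new`]):
    ///
    /// ```no_run
    /// # use rust_bert::bert::{BertConfig, BertEmbeddings, BertModel};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::{nn, no_grad, Device, Kind, Tensor};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = BertConfig::from_file(config_path);
    /// # let bert_model: BertModel<BertEmbeddings> = BertModel::new(&vs.root() / "bert", &config);
    /// let (batch_size, sequence_length) = (2, 16);
    /// let input_ids = Tensor::randint(
    ///     config.vocab_size,
    ///     [batch_size, sequence_length],
    ///     (Kind::Int64, device),
    /// );
    /// let mask = Tensor::ones([batch_size, sequence_length], (Kind::Int64, device));
    ///
    /// let output = no_grad(|| {
    ///     bert_model
    ///         .forward_t(Some(&input_ids), Some(&mask), None, None, None, None, None, false)
    ///         .unwrap()
    /// });
    /// assert_eq!(
    ///     output.hidden_state.size(),
    ///     [batch_size, sequence_length, config.hidden_size]
    /// );
    /// ```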
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
encoder_mask: Option<&Tensor>,
train: bool,
) -> Result<BertModelOutput, RustBertError> {
let (input_shape, device) =
get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
let calc_mask = Tensor::ones(&input_shape, (Kind::Int8, device));
let mask = mask.unwrap_or(&calc_mask);
let extended_attention_mask = match mask.dim() {
3 => mask.unsqueeze(1),
2 => {
if self.is_decoder {
let seq_ids = Tensor::arange(input_shape[1], (Kind::Int8, device));
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat([
input_shape[0],
input_shape[1],
1,
]);
let causal_mask = causal_mask.le_tensor(&seq_ids.unsqueeze(0).unsqueeze(-1));
causal_mask * mask.unsqueeze(1).unsqueeze(1)
} else {
mask.unsqueeze(1).unsqueeze(1)
}
}
_ => {
return Err(RustBertError::ValueError(
"Invalid attention mask dimension, must be 2 or 3".into(),
));
}
};
let embedding_output = self.embeddings.forward_t(
input_ids,
token_type_ids,
position_ids,
input_embeds,
train,
)?;
let extended_attention_mask: Tensor = ((extended_attention_mask
.ones_like()
.bitwise_xor_tensor(&extended_attention_mask))
* -10000.0)
.to_kind(embedding_output.kind());
let encoder_extended_attention_mask: Option<Tensor> =
            if self.is_decoder && encoder_hidden_states.is_some() {
let encoder_hidden_states = encoder_hidden_states.as_ref().unwrap();
let encoder_hidden_states_shape = encoder_hidden_states.size();
let encoder_mask = match encoder_mask {
Some(value) => value.copy(),
None => Tensor::ones(
[
encoder_hidden_states_shape[0],
encoder_hidden_states_shape[1],
],
(Kind::Int8, device),
),
};
match encoder_mask.dim() {
2 => Some(encoder_mask.unsqueeze(1).unsqueeze(1)),
3 => Some(encoder_mask.unsqueeze(1)),
_ => {
return Err(RustBertError::ValueError(
"Invalid attention mask dimension, must be 2 or 3".into(),
));
}
}
} else {
None
};
let encoder_output = self.encoder.forward_t(
&embedding_output,
Some(&extended_attention_mask),
encoder_hidden_states,
encoder_extended_attention_mask.as_ref(),
train,
);
let pooled_output = self
.pooler
.as_ref()
.map(|pooler| pooler.forward(&encoder_output.hidden_state));
Ok(BertModelOutput {
hidden_state: encoder_output.hidden_state,
pooled_output,
all_hidden_states: encoder_output.all_hidden_states,
all_attentions: encoder_output.all_attentions,
})
}
}
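/// Transform applied to the hidden states ahead of the language modeling head:
/// a dense projection, the configured activation, and layer normalization.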
pub struct BertPredictionHeadTransform {
dense: nn::Linear,
activation: TensorFunction,
layer_norm: nn::LayerNorm,
}
impl BertPredictionHeadTransform {
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertPredictionHeadTransform
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let dense = nn::linear(
p / "dense",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let activation = config.hidden_act.get_function();
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let layer_norm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
BertPredictionHeadTransform {
dense,
activation,
layer_norm,
}
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
self.activation.get_fn()(&hidden_states.apply(&self.dense)).apply(&self.layer_norm)
}
}
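/// Language modeling head for BERT: transforms the hidden states, then projects
/// them to vocabulary-size logits through a no-bias linear layer plus a separately
/// stored bias variable (matching the checkpoint layout of the original weights).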
pub struct BertLMPredictionHead {
transform: BertPredictionHeadTransform,
decoder: LinearNoBias,
bias: Tensor,
}
impl BertLMPredictionHead {
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertLMPredictionHead
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow() / "predictions";
let transform = BertPredictionHeadTransform::new(&p / "transform", config);
let decoder = linear_no_bias(
&p / "decoder",
config.hidden_size,
config.vocab_size,
Default::default(),
);
let bias = p.var("bias", &[config.vocab_size], DEFAULT_KAIMING_UNIFORM);
BertLMPredictionHead {
transform,
decoder,
bias,
}
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
self.transform.forward(hidden_states).apply(&self.decoder) + &self.bias
}
}
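/// BERT for masked language modeling: the base model (without pooler) followed by
/// an LM prediction head returning vocabulary-size scores for each position.
///
/// A minimal usage sketch (the config path is a placeholder; weights would normally
/// be loaded into the `VarStore` before inference):
///
/// ```no_run
/// use rust_bert::bert::{BertConfig, BertForMaskedLM};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device, Kind, Tensor};
///
/// let config = BertConfig::from_file(Path::new("path/to/config.json"));
/// let device = Device::Cpu;
/// let vs = nn::VarStore::new(device);
/// let model = BertForMaskedLM::new(vs.root(), &config);
///
/// let input_ids = Tensor::randint(config.vocab_size, [1, 8], (Kind::Int64, device));
/// let output = model.forward_t(Some(&input_ids), None, None, None, None, None, None, false);
/// // prediction_scores: (batch size, sequence length, vocab_size)
/// assert_eq!(output.prediction_scores.size(), [1, 8, config.vocab_size]);
/// ```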
pub struct BertForMaskedLM {
bert: BertModel<BertEmbeddings>,
cls: BertLMPredictionHead,
}
impl BertForMaskedLM {
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForMaskedLM
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let bert = BertModel::new_with_optional_pooler(p / "bert", config, false);
let cls = BertLMPredictionHead::new(p / "cls", config);
BertForMaskedLM { bert, cls }
}
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
encoder_mask: Option<&Tensor>,
train: bool,
) -> BertMaskedLMOutput {
let base_model_output = self
.bert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
encoder_hidden_states,
encoder_mask,
train,
)
.unwrap();
let prediction_scores = self.cls.forward(&base_model_output.hidden_state);
BertMaskedLMOutput {
prediction_scores,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
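/// BERT for sequence classification: the base model's pooled output followed by
/// dropout and a linear classifier. The number of labels is taken from the
/// `id2label` map in the configuration; construction fails with an
/// `InvalidConfigurationError` if it is missing.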
pub struct BertForSequenceClassification {
bert: BertModel<BertEmbeddings>,
dropout: Dropout,
classifier: nn::Linear,
}
impl BertForSequenceClassification {
pub fn new<'p, P>(
p: P,
config: &BertConfig,
) -> Result<BertForSequenceClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let bert = BertModel::new(p / "bert", config);
let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = config
.id2label
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
p / "classifier",
config.hidden_size,
num_labels,
Default::default(),
);
Ok(BertForSequenceClassification {
bert,
dropout,
classifier,
})
}
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> BertSequenceClassificationOutput {
let base_model_output = self
.bert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
None,
None,
train,
)
.unwrap();
let logits = base_model_output
.pooled_output
.unwrap()
.apply_t(&self.dropout, train)
.apply(&self.classifier);
BertSequenceClassificationOutput {
logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
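/// BERT for multiple-choice tasks: choices are flattened into the batch dimension,
/// scored through the pooled output and a single-unit classifier, and the logits
/// are reshaped back to (batch size, number of choices).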
pub struct BertForMultipleChoice {
bert: BertModel<BertEmbeddings>,
dropout: Dropout,
classifier: nn::Linear,
}
impl BertForMultipleChoice {
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForMultipleChoice
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let bert = BertModel::new(p / "bert", config);
let dropout = Dropout::new(config.hidden_dropout_prob);
let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
BertForMultipleChoice {
bert,
dropout,
classifier,
}
}
pub fn forward_t(
&self,
input_ids: &Tensor,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
train: bool,
) -> BertSequenceClassificationOutput {
let num_choices = input_ids.size()[1];
let input_ids = input_ids.view((-1, *input_ids.size().last().unwrap()));
let mask = mask.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
let token_type_ids =
token_type_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
let position_ids =
position_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
let base_model_output = self
.bert
.forward_t(
Some(&input_ids),
mask.as_ref(),
token_type_ids.as_ref(),
position_ids.as_ref(),
None,
None,
None,
train,
)
.unwrap();
let logits = base_model_output
.pooled_output
.unwrap()
.apply_t(&self.dropout, train)
.apply(&self.classifier)
.view((-1, num_choices));
BertSequenceClassificationOutput {
logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
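/// BERT for token classification (e.g. NER, part-of-speech tagging): the base model
/// without pooler, followed by dropout and a per-token linear classifier. As for
/// sequence classification, the number of labels is derived from `id2label`.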
pub struct BertForTokenClassification {
bert: BertModel<BertEmbeddings>,
dropout: Dropout,
classifier: nn::Linear,
}
impl BertForTokenClassification {
pub fn new<'p, P>(
p: P,
config: &BertConfig,
) -> Result<BertForTokenClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let bert = BertModel::new_with_optional_pooler(p / "bert", config, false);
let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = config
.id2label
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
p / "classifier",
config.hidden_size,
num_labels,
Default::default(),
);
Ok(BertForTokenClassification {
bert,
dropout,
classifier,
})
}
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> BertTokenClassificationOutput {
let base_model_output = self
.bert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
None,
None,
train,
)
.unwrap();
let logits = base_model_output
.hidden_state
.apply_t(&self.dropout, train)
.apply(&self.classifier);
BertTokenClassificationOutput {
logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
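/// BERT for extractive question answering (e.g. SQuAD): a two-unit linear layer over
/// the hidden states produces start and end position logits for the answer span.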
pub struct BertForQuestionAnswering {
bert: BertModel<BertEmbeddings>,
qa_outputs: nn::Linear,
}
impl BertForQuestionAnswering {
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForQuestionAnswering
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let bert = BertModel::new(p / "bert", config);
let num_labels = 2;
let qa_outputs = nn::linear(
p / "qa_outputs",
config.hidden_size,
num_labels,
Default::default(),
);
BertForQuestionAnswering { bert, qa_outputs }
}
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
mask: Option<&Tensor>,
token_type_ids: Option<&Tensor>,
position_ids: Option<&Tensor>,
input_embeds: Option<&Tensor>,
train: bool,
) -> BertQuestionAnsweringOutput {
let base_model_output = self
.bert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
None,
None,
train,
)
.unwrap();
let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
let logits = sequence_output.split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze_dim(-1);
let end_logits = end_logits.squeeze_dim(-1);
BertQuestionAnsweringOutput {
start_logits,
end_logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
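/// Alias for a base BERT model used to generate sentence embeddings, as in the
/// sentence-transformers checkpoints listed in the resources above.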
pub type BertForSentenceEmbeddings = BertModel<BertEmbeddings>;
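/// Container for the output of the base BERT model.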
pub struct BertModelOutput {
pub hidden_state: Tensor,
pub pooled_output: Option<Tensor>,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
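/// Container for the output of a BERT masked language model.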
pub struct BertMaskedLMOutput {
pub prediction_scores: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
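/// Container for the output of a BERT sequence or multiple-choice classification model.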
pub struct BertSequenceClassificationOutput {
pub logits: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
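/// Container for the output of a BERT token classification model.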
pub struct BertTokenClassificationOutput {
pub logits: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
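/// Container for the output of a BERT question answering model.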
pub struct BertQuestionAnsweringOutput {
pub start_logits: Tensor,
pub end_logits: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
#[cfg(test)]
mod test {
use tch::Device;
use crate::{
resources::{RemoteResource, ResourceProvider},
Config,
};
use super::*;
#[test]
    #[ignore] // compilation is sufficient: this test only checks that BertModel is Send
    fn bert_model_send() {
let config_resource = Box::new(RemoteResource::from_pretrained(BertConfigResources::BERT));
        let config_path = config_resource
            .get_local_path()
            .expect("Could not get config resource");
let device = Device::cuda_if_available();
let vs = nn::VarStore::new(device);
let config = BertConfig::from_file(config_path);
let _: Box<dyn Send> = Box::new(BertModel::<BertEmbeddings>::new(vs.root(), &config));
}
}