use std::borrow::Borrow;
use std::convert::{TryFrom, TryInto};
use rust_tokenizers::tokenizer::TruncationStrategy;
use tch::{nn, Tensor};
use crate::albert::AlbertForSentenceEmbeddings;
use crate::bert::BertForSentenceEmbeddings;
use crate::distilbert::DistilBertForSentenceEmbeddings;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::pipelines::sentence_embeddings::layers::{Dense, DenseConfig, Pooling, PoolingConfig};
use crate::pipelines::sentence_embeddings::{
AttentionHead, AttentionLayer, AttentionOutput, Embedding, SentenceEmbeddingsConfig,
SentenceEmbeddingsModulesConfig, SentenceEmbeddingsSentenceBertConfig,
SentenceEmbeddingsTokenizerConfig,
};
use crate::roberta::RobertaForSentenceEmbeddings;
use crate::t5::T5ForSentenceEmbeddings;
use crate::{Config, RustBertError};
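/// Abstraction over the transformer architectures supported for sentence embeddings.
///
/// Each variant wraps the `*ForSentenceEmbeddings` flavour of the corresponding
/// model, so the pipeline can drive any supported backbone through a single
/// `forward` interface.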
pub enum SentenceEmbeddingsOption {
Bert(BertForSentenceEmbeddings),
DistilBert(DistilBertForSentenceEmbeddings),
Roberta(RobertaForSentenceEmbeddings),
Albert(AlbertForSentenceEmbeddings),
T5(T5ForSentenceEmbeddings),
}
impl SentenceEmbeddingsOption {
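    /// Instantiates a new transformer of the given `ModelType`.
    ///
    /// # Arguments
    ///
    /// * `transformer_type` - `ModelType` indicating the transformer to load
    /// * `p` - `tch::nn::Path` pointing to the variable store holding the model weights
    /// * `config` - Model configuration; its type must be compatible with `transformer_type`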
pub fn new<'p, P>(
transformer_type: ModelType,
p: P,
config: &ConfigOption,
) -> Result<Self, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
use SentenceEmbeddingsOption::*;
let option = match transformer_type {
ModelType::Bert => Bert(BertForSentenceEmbeddings::new(p, &(config.try_into()?))),
ModelType::DistilBert => DistilBert(DistilBertForSentenceEmbeddings::new(
p,
&(config.try_into()?),
)),
ModelType::Roberta => Roberta(RobertaForSentenceEmbeddings::new_with_optional_pooler(
p,
&(config.try_into()?),
false,
)),
ModelType::Albert => Albert(AlbertForSentenceEmbeddings::new(p, &(config.try_into()?))),
ModelType::T5 => T5(T5ForSentenceEmbeddings::new(p, &(config.try_into()?))),
_ => {
return Err(RustBertError::InvalidConfigurationError(format!(
"Unsupported transformer model {transformer_type:?} for Sentence Embeddings"
)));
}
};
Ok(option)
}
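    /// Runs a forward pass through the wrapped transformer.
    ///
    /// Returns the last hidden state together with the per-layer attention
    /// weights when the underlying model exposes them.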
pub fn forward(
&self,
tokens_ids: &Tensor,
tokens_masks: &Tensor,
) -> Result<(Tensor, Option<Vec<Tensor>>), RustBertError> {
match self {
Self::Bert(transformer) => transformer
.forward_t(
Some(tokens_ids),
Some(tokens_masks),
None,
None,
None,
None,
None,
false,
)
.map(|transformer_output| {
(
transformer_output.hidden_state,
transformer_output.all_attentions,
)
}),
Self::DistilBert(transformer) => transformer
.forward_t(Some(tokens_ids), Some(tokens_masks), None, false)
.map(|transformer_output| {
(
transformer_output.hidden_state,
transformer_output.all_attentions,
)
}),
Self::Roberta(transformer) => transformer
.forward_t(
Some(tokens_ids),
Some(tokens_masks),
None,
None,
None,
None,
None,
false,
)
.map(|transformer_output| {
(
transformer_output.hidden_state,
transformer_output.all_attentions,
)
}),
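            // ALBERT returns one attention tensor per inner group within each layer;
            // average them so every architecture yields a single tensor per layer.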
Self::Albert(transformer) => transformer
.forward_t(
Some(tokens_ids),
Some(tokens_masks),
None,
None,
None,
false,
)
.map(|transformer_output| {
(
transformer_output.hidden_state,
transformer_output.all_attentions.map(|attentions| {
attentions
.into_iter()
.map(|tensors| {
let num_inner_groups = tensors.len() as f64;
tensors.into_iter().sum::<Tensor>() / num_inner_groups
})
.collect()
}),
)
}),
Self::T5(transformer) => transformer.forward(tokens_ids, tokens_masks),
}
}
}
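/// End-to-end pipeline mapping raw sentences to fixed-size embeddings.
///
/// Follows the Sentence-BERT architecture: tokenization, transformer encoding,
/// pooling, an optional dense projection and optional L2 normalization.
///
/// A minimal usage sketch, assuming the `SentenceEmbeddingsBuilder` helper
/// exposed by this module and a remotely fetched pretrained model:
///
/// ```no_run
/// use rust_bert::pipelines::sentence_embeddings::{
///     SentenceEmbeddingsBuilder, SentenceEmbeddingsModelType,
/// };
///
/// fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let model = SentenceEmbeddingsBuilder::remote(SentenceEmbeddingsModelType::AllMiniLmL12V2)
///         .create_model()?;
///     let embeddings = model.encode(&["This is a sentence", "Each sentence becomes a vector"])?;
///     assert_eq!(embeddings.len(), 2);
///     Ok(())
/// }
/// ```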
pub struct SentenceEmbeddingsModel {
sentence_bert_config: SentenceEmbeddingsSentenceBertConfig,
tokenizer: TokenizerOption,
tokenizer_truncation_strategy: TruncationStrategy,
var_store: nn::VarStore,
transformer: SentenceEmbeddingsOption,
transformer_config: ConfigOption,
pooling_layer: Pooling,
dense_layer: Option<Dense>,
normalize_embeddings: bool,
embeddings_dim: i64,
}
impl SentenceEmbeddingsModel {
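    /// Creates a new `SentenceEmbeddingsModel` from a `SentenceEmbeddingsConfig`,
    /// building the tokenizer from the resources referenced in the configuration.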
pub fn new(config: SentenceEmbeddingsConfig) -> Result<Self, RustBertError> {
let transformer_type = config.transformer_type;
let tokenizer_vocab_resource = &config.tokenizer_vocab_resource;
let tokenizer_merges_resource = &config.tokenizer_merges_resource;
let tokenizer_config_resource = &config.tokenizer_config_resource;
let sentence_bert_config_resource = &config.sentence_bert_config_resource;
let tokenizer_config = SentenceEmbeddingsTokenizerConfig::from_file(
tokenizer_config_resource.get_local_path()?,
);
let sentence_bert_config = SentenceEmbeddingsSentenceBertConfig::from_file(
sentence_bert_config_resource.get_local_path()?,
);
let tokenizer = TokenizerOption::from_file(
transformer_type,
tokenizer_vocab_resource
.get_local_path()?
.to_string_lossy()
.as_ref(),
tokenizer_merges_resource
.as_ref()
.map(|resource| resource.get_local_path())
.transpose()?
.map(|path| path.to_string_lossy().into_owned())
.as_deref(),
tokenizer_config
.do_lower_case
.unwrap_or(sentence_bert_config.do_lower_case),
tokenizer_config.strip_accents,
tokenizer_config.add_prefix_space,
)?;
Self::new_with_tokenizer(config, tokenizer)
}
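    /// Creates a new `SentenceEmbeddingsModel` from a `SentenceEmbeddingsConfig`
    /// and an already instantiated tokenizer, loading the transformer weights and
    /// the pooling, dense and normalization modules declared in the modules configuration.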
pub fn new_with_tokenizer(
config: SentenceEmbeddingsConfig,
tokenizer: TokenizerOption,
) -> Result<Self, RustBertError> {
let SentenceEmbeddingsConfig {
modules_config_resource,
sentence_bert_config_resource,
tokenizer_config_resource: _,
tokenizer_vocab_resource: _,
tokenizer_merges_resource: _,
transformer_type,
transformer_config_resource,
transformer_weights_resource,
pooling_config_resource,
dense_config_resource,
dense_weights_resource,
device,
kind,
} = config;
let modules =
SentenceEmbeddingsModulesConfig::from_file(modules_config_resource.get_local_path()?)
.validate()?;
let sentence_bert_config = SentenceEmbeddingsSentenceBertConfig::from_file(
sentence_bert_config_resource.get_local_path()?,
);
let mut var_store = nn::VarStore::new(device);
let transformer_config = ConfigOption::from_file(
transformer_type,
transformer_config_resource.get_local_path()?,
);
let transformer =
SentenceEmbeddingsOption::new(transformer_type, var_store.root(), &transformer_config)?;
crate::resources::load_weights(
&transformer_weights_resource,
&mut var_store,
kind,
device,
)?;
let pooling_config = PoolingConfig::from_file(pooling_config_resource.get_local_path()?);
let mut embeddings_dim = pooling_config.word_embedding_dimension;
let pooling_layer = Pooling::new(pooling_config);
let dense_layer = if modules.dense_module().is_some() {
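            // A dense module declared in the modules configuration implies that the
            // dense config and weights resources were provided; a missing resource
            // here is a configuration error.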
let dense_config =
DenseConfig::from_file(dense_config_resource.unwrap().get_local_path()?);
embeddings_dim = dense_config.out_features;
Some(Dense::new(
dense_config,
dense_weights_resource.unwrap().get_local_path()?,
device,
)?)
} else {
None
};
let normalize_embeddings = modules.has_normalization();
Ok(Self {
tokenizer,
sentence_bert_config,
tokenizer_truncation_strategy: TruncationStrategy::LongestFirst,
var_store,
transformer,
transformer_config,
pooling_layer,
dense_layer,
normalize_embeddings,
embeddings_dim,
})
}
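    /// Returns a reference to the pipeline's tokenizer.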
pub fn get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
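    /// Returns a mutable reference to the pipeline's tokenizer.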
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
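    /// Sets the truncation strategy applied to inputs exceeding the maximum sequence length.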
pub fn set_tokenizer_truncation(&mut self, truncation_strategy: TruncationStrategy) {
self.tokenizer_truncation_strategy = truncation_strategy;
}
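    /// Returns the dimension of the produced embeddings: the dense layer's output
    /// size when one is configured, otherwise the pooled word embedding dimension.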
pub fn get_embedding_dim(&self) -> Result<i64, RustBertError> {
Ok(self.embeddings_dim)
}
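    /// Tokenizes a batch of sentences, right-padding every sequence to the longest
    /// one in the batch and deriving the matching attention masks.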
pub fn tokenize<S>(&self, inputs: &[S]) -> SentenceEmbeddingsTokenizerOutput
where
S: AsRef<str> + Send + Sync,
{
let tokenized_input = self.tokenizer.encode_list(
inputs,
self.sentence_bert_config.max_seq_length,
&self.tokenizer_truncation_strategy,
0,
);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap_or(0);
let pad_token_id = self.tokenizer.get_pad_id().unwrap_or(0);
let tokens_ids = tokenized_input
.into_iter()
.map(|input| {
                let mut token_ids = input.token_ids;
                // Right-pad to the longest sequence in the batch
                token_ids.resize(max_len, pad_token_id);
                token_ids
})
.collect::<Vec<_>>();
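        // Attention masks: 1 for real tokens, 0 for padding positions.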
let tokens_masks = tokens_ids
.iter()
.map(|input| {
Tensor::from_slice(
&input
.iter()
.map(|&e| i64::from(e != pad_token_id))
.collect::<Vec<_>>(),
)
})
.collect::<Vec<_>>();
let tokens_ids = tokens_ids
.into_iter()
            .map(|input| Tensor::from_slice(&input))
.collect::<Vec<_>>();
SentenceEmbeddingsTokenizerOutput {
tokens_ids,
tokens_masks,
}
}
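    /// Computes sentence embeddings for a batch of inputs, returned as a single
    /// `Tensor` of shape (batch size, embedding dimension), together with the
    /// attention weights when the transformer exposes them.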
pub fn encode_as_tensor<S>(
&self,
inputs: &[S],
) -> Result<SentenceEmbeddingsModelOutput, RustBertError>
where
S: AsRef<str> + Send + Sync,
{
let SentenceEmbeddingsTokenizerOutput {
tokens_ids,
tokens_masks,
} = self.tokenize(inputs);
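        // Guard against an empty tokenized batch before stacking the tensors.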
if tokens_ids.is_empty() {
return Err(RustBertError::ValueError(
"No n-gram found in the document. \
Try allowing smaller n-gram sizes or relax stopword/forbidden characters criteria."
.to_string(),
));
}
let tokens_ids = Tensor::stack(&tokens_ids, 0).to(self.var_store.device());
let tokens_masks = Tensor::stack(&tokens_masks, 0).to(self.var_store.device());
let (tokens_embeddings, all_attentions) =
tch::no_grad(|| self.transformer.forward(&tokens_ids, &tokens_masks))?;
        let pooled_embeddings =
            tch::no_grad(|| self.pooling_layer.forward(tokens_embeddings, &tokens_masks));
        let maybe_linear = if let Some(dense_layer) = &self.dense_layer {
            tch::no_grad(|| dense_layer.forward(&pooled_embeddings))
        } else {
            pooled_embeddings
};
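        // L2-normalize along the embedding dimension; the clamp guards against
        // division by zero for all-zero vectors.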
let maybe_normalized = if self.normalize_embeddings {
let norm = &maybe_linear
.norm_scalaropt_dim(2, [1], true)
.clamp_min(1e-12)
.expand_as(&maybe_linear);
maybe_linear / norm
} else {
maybe_linear
};
Ok(SentenceEmbeddingsModelOutput {
embeddings: maybe_normalized,
all_attentions,
})
}
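    /// Computes sentence embeddings for a batch of inputs, returned as plain
    /// `Vec<f32>` vectors, one per input sentence.
    ///
    /// A sketch of a typical downstream use, comparing two returned embeddings
    /// by cosine similarity (helper shown for illustration only):
    ///
    /// ```
    /// fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    ///     let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    ///     let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    ///     dot / (norm(a) * norm(b)).max(1e-12)
    /// }
    /// ```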
pub fn encode<S>(&self, inputs: &[S]) -> Result<Vec<Embedding>, RustBertError>
where
S: AsRef<str> + Send + Sync,
{
let SentenceEmbeddingsModelOutput { embeddings, .. } = self.encode_as_tensor(inputs)?;
Ok(Vec::try_from(embeddings)?)
}
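    /// Number of transformer layers, read from the architecture-specific configuration.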
fn nb_layers(&self) -> usize {
use SentenceEmbeddingsOption::*;
match (&self.transformer, &self.transformer_config) {
(Bert(_), ConfigOption::Bert(conf)) => conf.num_hidden_layers as usize,
(Bert(_), _) => unreachable!(),
(DistilBert(_), ConfigOption::DistilBert(conf)) => conf.n_layers as usize,
(DistilBert(_), _) => unreachable!(),
            (Roberta(_), ConfigOption::Roberta(conf)) => conf.num_hidden_layers as usize,
(Roberta(_), _) => unreachable!(),
(Albert(_), ConfigOption::Albert(conf)) => conf.num_hidden_layers as usize,
(Albert(_), _) => unreachable!(),
(T5(_), ConfigOption::T5(conf)) => conf.num_layers as usize,
(T5(_), _) => unreachable!(),
}
}
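    /// Number of attention heads per layer, read from the architecture-specific configuration.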
fn nb_heads(&self) -> usize {
use SentenceEmbeddingsOption::*;
match (&self.transformer, &self.transformer_config) {
(Bert(_), ConfigOption::Bert(conf)) => conf.num_attention_heads as usize,
(Bert(_), _) => unreachable!(),
(DistilBert(_), ConfigOption::DistilBert(conf)) => conf.n_heads as usize,
(DistilBert(_), _) => unreachable!(),
(Roberta(_), ConfigOption::Roberta(conf)) => conf.num_attention_heads as usize,
(Roberta(_), _) => unreachable!(),
(Albert(_), ConfigOption::Albert(conf)) => conf.num_attention_heads as usize,
(Albert(_), _) => unreachable!(),
(T5(_), ConfigOption::T5(conf)) => conf.num_heads as usize,
(T5(_), _) => unreachable!(),
}
}
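    /// Computes sentence embeddings together with the per-layer, per-head attention
    /// maps for every input. Fails if the underlying transformer does not return
    /// attention weights.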
pub fn encode_with_attention<S>(
&self,
inputs: &[S],
) -> Result<(Vec<Embedding>, Vec<AttentionOutput>), RustBertError>
where
S: AsRef<str> + Send + Sync,
{
let SentenceEmbeddingsModelOutput {
embeddings,
all_attentions,
} = self.encode_as_tensor(inputs)?;
let embeddings = Vec::try_from(embeddings)?;
let all_attentions = all_attentions.ok_or_else(|| {
            RustBertError::InvalidConfigurationError("No attention weights output by the model".into())
})?;
let attention_outputs = (0..inputs.len() as i64)
.map(|i| {
let mut attention_output = AttentionOutput::with_capacity(self.nb_layers());
for layer in all_attentions.iter() {
let mut attention_layer = AttentionLayer::with_capacity(self.nb_heads());
for head in 0..self.nb_heads() {
let attention_slice = layer
.slice(0, i, i + 1, 1)
.slice(1, head as i64, head as i64 + 1, 1)
.squeeze();
let attention_head = AttentionHead::try_from(attention_slice).unwrap();
attention_layer.push(attention_head);
}
attention_output.push(attention_layer);
}
attention_output
})
.collect::<Vec<AttentionOutput>>();
Ok((embeddings, attention_outputs))
}
}
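/// Tokenizer output for a batch: one tensor of padded token ids and one
/// attention-mask tensor per input sentence.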
pub struct SentenceEmbeddingsTokenizerOutput {
pub tokens_ids: Vec<Tensor>,
pub tokens_masks: Vec<Tensor>,
}
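/// Model output for a batch: the stacked sentence embeddings and, optionally,
/// the attention weights of every transformer layer.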
pub struct SentenceEmbeddingsModelOutput {
pub embeddings: Tensor,
pub all_attentions: Option<Vec<Tensor>>,
}