use crate::pipelines::common::TokenizerOption;
use crate::pipelines::keywords_extraction::tokenizer::StopWordsTokenizer;
#[cfg(feature = "remote")]
use crate::pipelines::sentence_embeddings::SentenceEmbeddingsModelType;
use crate::pipelines::sentence_embeddings::{
SentenceEmbeddingsConfig, SentenceEmbeddingsModel, SentenceEmbeddingsSentenceBertConfig,
SentenceEmbeddingsTokenizerConfig,
};
use crate::{Config, RustBertError};
use regex::Regex;
use rust_tokenizers::Offset;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::cmp::min;
use std::collections::{HashMap, HashSet};
/// A keyword extracted from a document, along with its relevance score and
/// the location(s) where it occurs in the source text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Keyword {
    /// Keyword text (a candidate n-gram produced by the tokenizer).
    pub text: String,
    /// Relevance score assigned by the scorer; keywords for a document are
    /// returned sorted by descending score.
    pub score: f32,
    /// Offsets of every occurrence of this keyword in the input document.
    pub offsets: Vec<Offset>,
}
/// Scoring strategy used to rank candidate keywords against a document.
// Fieldless public enum: derive the conventional traits so callers can debug-print,
// copy, and compare scorer choices. This is backward-compatible (additive only).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KeywordScorerType {
    /// Rank candidates by cosine similarity to the document embedding.
    CosineSimilarity,
    /// Maximal margin relevance scoring (uses the `diversity` setting).
    MaximalMarginRelevance,
    /// Max-sum scoring (uses the `max_sum_candidates` setting).
    MaxSum,
}
/// Configuration for the keyword extraction pipeline.
pub struct KeywordExtractionConfig<'a> {
    /// Configuration of the sentence embeddings model used to embed both
    /// documents and candidate keywords.
    pub sentence_embeddings_config: SentenceEmbeddingsConfig,
    /// Optional stop-words passed to the `StopWordsTokenizer`.
    pub tokenizer_stopwords: Option<HashSet<&'a str>>,
    /// Optional regex pattern passed to the `StopWordsTokenizer`.
    pub tokenizer_pattern: Option<Regex>,
    /// Optional characters forbidden inside candidate n-grams (passed to the tokenizer).
    pub tokenizer_forbidden_ngram_chars: Option<&'a [char]>,
    /// Scoring strategy used to rank candidate keywords.
    pub scorer_type: KeywordScorerType,
    /// (min, max) n-gram sizes for candidate keywords.
    pub ngram_range: (usize, usize),
    /// Maximum number of keywords returned per document.
    pub num_keywords: usize,
    /// Diversity factor; presumably consumed by the MaximalMarginRelevance
    /// scorer — confirm in the scorer implementation.
    pub diversity: Option<f64>,
    /// Candidate pool size; presumably consumed by the MaxSum scorer —
    /// confirm in the scorer implementation.
    pub max_sum_candidates: Option<usize>,
}
#[cfg(feature = "remote")]
impl Default for KeywordExtractionConfig<'_> {
    /// Default setup: AllMiniLmL6V2 sentence embeddings, default tokenizer
    /// settings, cosine-similarity scoring, unigram candidates and 5 keywords
    /// per document.
    fn default() -> Self {
        KeywordExtractionConfig {
            sentence_embeddings_config: SentenceEmbeddingsConfig::from(
                SentenceEmbeddingsModelType::AllMiniLmL6V2,
            ),
            tokenizer_stopwords: None,
            tokenizer_pattern: None,
            tokenizer_forbidden_ngram_chars: None,
            scorer_type: KeywordScorerType::CosineSimilarity,
            ngram_range: (1, 1),
            num_keywords: 5,
            diversity: None,
            max_sum_candidates: None,
        }
    }
}
/// Keyword extraction model: pairs a sentence embeddings model (to embed
/// documents and candidate words) with a stop-words tokenizer (to generate
/// the candidates), plus the scoring settings captured from the config.
pub struct KeywordExtractionModel<'a> {
    /// Sentence embeddings model used to embed documents and candidate keywords.
    pub sentence_embeddings_model: SentenceEmbeddingsModel,
    /// Tokenizer producing candidate n-grams and their offsets.
    pub tokenizer: StopWordsTokenizer<'a>,
    // Scoring strategy used to rank candidates (see `KeywordScorerType`).
    scorer_type: KeywordScorerType,
    // (min, max) candidate n-gram sizes.
    ngram_range: (usize, usize),
    // Maximum number of keywords returned per document.
    num_keywords: usize,
    // Forwarded to the scorer; presumably used by MaximalMarginRelevance — confirm.
    diversity: Option<f64>,
    // Forwarded to the scorer; presumably used by MaxSum — confirm.
    max_sum_candidates: Option<usize>,
}
impl<'a> KeywordExtractionModel<'a> {
    /// Builds a new `KeywordExtractionModel` from a `KeywordExtractionConfig`.
    ///
    /// Loads the tokenizer and sentence-bert configurations from the resources
    /// referenced by the embedded `SentenceEmbeddingsConfig`, then instantiates
    /// the sentence embeddings model and the stop-words tokenizer.
    ///
    /// # Errors
    /// Returns a `RustBertError` if a resource cannot be resolved to a local
    /// path or the sentence embeddings model fails to load.
    pub fn new(
        config: KeywordExtractionConfig<'a>,
    ) -> Result<KeywordExtractionModel<'a>, RustBertError> {
        let tokenizer_config = SentenceEmbeddingsTokenizerConfig::from_file(
            config
                .sentence_embeddings_config
                .tokenizer_config_resource
                .get_local_path()?,
        );
        let sentence_bert_config = SentenceEmbeddingsSentenceBertConfig::from_file(
            config
                .sentence_embeddings_config
                .sentence_bert_config_resource
                .get_local_path()?,
        );
        let sentence_embeddings_model =
            SentenceEmbeddingsModel::new(config.sentence_embeddings_config)?;
        // The tokenizer config takes precedence; fall back to the sentence-bert
        // configuration when `do_lower_case` is not set there.
        let do_lower_case = tokenizer_config
            .do_lower_case
            .unwrap_or(sentence_bert_config.do_lower_case);
        let tokenizer = StopWordsTokenizer::new(
            config.tokenizer_stopwords,
            config.tokenizer_pattern,
            do_lower_case,
            config.tokenizer_forbidden_ngram_chars,
        );
        Ok(Self {
            sentence_embeddings_model,
            tokenizer,
            scorer_type: config.scorer_type,
            ngram_range: config.ngram_range,
            num_keywords: config.num_keywords,
            diversity: config.diversity,
            max_sum_candidates: config.max_sum_candidates,
        })
    }

    /// Returns a reference to the underlying sentence embeddings tokenizer.
    pub fn get_tokenizer(&self) -> &TokenizerOption {
        self.sentence_embeddings_model.get_tokenizer()
    }

    /// Returns a mutable reference to the underlying sentence embeddings tokenizer.
    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        self.sentence_embeddings_model.get_tokenizer_mut()
    }

    /// Extracts keywords for each input document.
    ///
    /// Candidate n-grams are generated per document, all candidates are embedded
    /// in a single batch alongside the documents, and the configured scorer
    /// selects up to `num_keywords` candidates per document. Each document's
    /// keywords are returned sorted by descending score.
    ///
    /// # Arguments
    /// * `inputs` - slice of string-like documents to extract keywords from.
    ///
    /// # Errors
    /// Returns a `RustBertError` if the embedding computation fails.
    pub fn predict<S>(&self, inputs: &[S]) -> Result<Vec<Vec<Keyword>>, RustBertError>
    where
        S: AsRef<str> + Send + Sync,
    {
        // Per-document map of candidate n-gram -> occurrence offsets.
        let words = self.tokenizer.tokenize_list(inputs, self.ngram_range);
        // Flatten candidates into one batch; boundaries map documents back
        // into contiguous [start, end) ranges of the flat list.
        let (flat_word_list, document_boundaries) = Self::flatten_word_list(&words);
        let document_embeddings = self
            .sentence_embeddings_model
            .encode_as_tensor(inputs)?
            .embeddings;
        let word_embeddings = self
            .sentence_embeddings_model
            .encode_as_tensor(&flat_word_list)?;
        // One entry per document: preallocate to the known document count.
        let mut output_keywords: Vec<Vec<Keyword>> =
            Vec::with_capacity(document_boundaries.len());
        for (document_index, (start, end)) in document_boundaries.into_iter().enumerate() {
            let mut document_keywords = Vec::new();
            let document_embedding = document_embeddings
                .select(0, document_index as i64)
                .unsqueeze(0);
            // Slice out this document's candidate embeddings from the batch.
            let word_embeddings = word_embeddings
                .embeddings
                .slice(0, start as i64, end as i64, 1);
            // A document may have fewer candidates than requested keywords.
            let num_keywords = min(self.num_keywords, word_embeddings.size()[0] as usize);
            let local_top_word_indices = self.scorer_type.score_keywords(
                document_embedding,
                word_embeddings,
                num_keywords,
                self.diversity,
                self.max_sum_candidates,
            );
            for (index, score) in local_top_word_indices {
                // `index` is local to this document's slice; offset by `start`
                // to look up the word in the flat list.
                let word = flat_word_list[start + index];
                document_keywords.push(Keyword {
                    text: word.to_string(),
                    score,
                    offsets: words[document_index].get(word).unwrap().clone(),
                });
            }
            // `total_cmp` provides a total order over f32 (IEEE 754 totalOrder),
            // so sorting cannot panic even if a scorer produces NaN.
            document_keywords.sort_by(|a, b| b.score.total_cmp(&a.score));
            output_keywords.push(document_keywords);
        }
        Ok(output_keywords)
    }

    /// Flattens per-document candidate maps into a single word list, returning
    /// the flat list together with each document's [start, end) range into it.
    fn flatten_word_list(
        words: &'a [HashMap<Cow<str>, Vec<Offset>>],
    ) -> (Vec<&'a Cow<'a, str>>, Vec<(usize, usize)>) {
        let mut flat_word_list = Vec::new();
        let mut doc_boundaries = Vec::with_capacity(words.len());
        let mut current_index = 0;
        for doc_words_map in words {
            let doc_words = doc_words_map.keys();
            let doc_words_len = doc_words_map.len();
            flat_word_list.extend(doc_words);
            doc_boundaries.push((current_index, current_index + doc_words_len));
            current_index += doc_words_len;
        }
        (flat_word_list, doc_boundaries)
    }
}