use crate::albert::AlbertForTokenClassification;
use crate::bert::{
BertConfigResources, BertForTokenClassification, BertModelResources, BertVocabResources,
};
use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::distilbert::DistilBertForTokenClassification;
use crate::electra::ElectraForTokenClassification;
use crate::fnet::FNetForTokenClassification;
use crate::longformer::LongformerForTokenClassification;
use crate::mobilebert::MobileBertForTokenClassification;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::roberta::RobertaForTokenClassification;
use crate::xlnet::XLNetForTokenClassification;
use rust_tokenizers::tokenizer::Tokenizer;
use rust_tokenizers::{
ConsolidatableTokens, ConsolidatedTokenIterator, Mask, Offset, TokenIdsWithOffsets, TokenTrait,
TokenizedInput,
};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::cmp::min;
use std::collections::HashMap;
use tch::nn::VarStore;
use tch::{nn, no_grad, Device, Kind, Tensor};
/// # Token generated by a `TokenClassificationModel`
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Token {
/// String representation of the token
pub text: String,
/// Confidence score for the predicted label (softmax probability)
pub score: f64,
/// Predicted label string
pub label: String,
/// Numeric index of the predicted label in the model's label mapping
pub label_index: i64,
/// Index of the input sequence this token belongs to
pub sentence: usize,
/// Position of the token within the tokenized sequence
pub index: u16,
/// Position of the word this token belongs to (sub-tokens of a word share this index)
pub word_index: u16,
/// Character offset of the token in the original text (`None` for special/generated tokens)
pub offset: Option<Offset>,
/// Mask category of the token (e.g. special token, sub-token continuation)
pub mask: Mask,
}
impl TokenTrait for Token {
    /// Character offset of this token in the source text, if known.
    fn offset(&self) -> Option<Offset> {
        self.offset
    }

    /// Mask category assigned to this token.
    fn mask(&self) -> Mask {
        self.mask
    }

    /// Borrow the token text as a string slice.
    fn as_str(&self) -> &str {
        &self.text
    }
}
/// Enables iteration over groups of sub-tokens belonging to the same word,
/// as required by `TokenClassificationModel::consolidate_tokens`.
impl ConsolidatableTokens<Token> for Vec<Token> {
fn iter_consolidate_tokens(&self) -> ConsolidatedTokenIterator<Token> {
ConsolidatedTokenIterator::new(self)
}
}
/// # Model input corresponding to a single (possibly partial) span of an example
#[derive(Debug)]
struct InputFeature {
/// Token ids fed to the model (including special tokens, padded at batch time)
input_ids: Vec<i64>,
/// Character offsets of the tokens in the original text (`None` for special/padding tokens)
offsets: Vec<Option<Offset>>,
/// Mask category for each token position
mask: Vec<Mask>,
/// Per-position flag: `true` if this span is the reference span producing the prediction
/// for that position (avoids duplicate outputs for overlapping spans)
reference_feature: Vec<bool>,
/// Index of the input example this feature was generated from
example_index: usize,
}
/// Function merging the labels of a group of sub-tokens into a single (index, label) pair.
type LabelAggregationFunction = Box<fn(&[Token]) -> (i64, String)>;

/// # Strategy used to aggregate sub-token labels when consolidating tokens
pub enum LabelAggregationOption {
/// Use the label of the first sub-token of the word
First,
/// Use the label of the last sub-token of the word
Last,
/// Use the most frequent label among the sub-tokens of the word
Mode,
/// Use a caller-provided aggregation function
Custom(LabelAggregationFunction),
}
/// # Configuration for a `TokenClassificationModel`
/// Bundles the resources and options needed to build the pipeline.
pub struct TokenClassificationConfig {
/// Model architecture (e.g. Bert, Roberta, ...)
pub model_type: ModelType,
/// Resource pointing to the model weights
pub model_resource: Resource,
/// Resource pointing to the model configuration
pub config_resource: Resource,
/// Resource pointing to the tokenizer vocabulary
pub vocab_resource: Resource,
/// Optional merges resource (only required by BPE-based tokenizers)
pub merges_resource: Option<Resource>,
/// Whether the tokenizer lower-cases the input
pub lower_case: bool,
/// Optional accent-stripping flag forwarded to the tokenizer
pub strip_accents: Option<bool>,
/// Optional prefix-space flag forwarded to the tokenizer
pub add_prefix_space: Option<bool>,
/// Device on which the model runs
pub device: Device,
/// Strategy used to merge sub-token labels during consolidation
pub label_aggregation_function: LabelAggregationOption,
/// Number of input features per forward pass
pub batch_size: usize,
}
impl TokenClassificationConfig {
    /// Creates a new `TokenClassificationConfig` from the given resources and
    /// tokenizer options.
    ///
    /// The device defaults to CUDA when available (CPU otherwise) and the batch
    /// size defaults to 64; both can be overridden on the returned value.
    ///
    /// # Arguments
    /// * `model_type` - model architecture to load
    /// * `model_resource` - resource pointing to the model weights
    /// * `config_resource` - resource pointing to the model configuration
    /// * `vocab_resource` - resource pointing to the tokenizer vocabulary
    /// * `merges_resource` - optional merges resource (BPE tokenizers)
    /// * `lower_case` - whether the tokenizer lower-cases the input
    /// * `strip_accents` - optional accent-stripping flag for the tokenizer
    /// * `add_prefix_space` - optional prefix-space flag for the tokenizer
    /// * `label_aggregation_function` - strategy merging sub-token labels
    pub fn new(
        model_type: ModelType,
        model_resource: Resource,
        config_resource: Resource,
        vocab_resource: Resource,
        merges_resource: Option<Resource>,
        lower_case: bool,
        strip_accents: impl Into<Option<bool>>,
        add_prefix_space: impl Into<Option<bool>>,
        label_aggregation_function: LabelAggregationOption,
    ) -> TokenClassificationConfig {
        // Normalize the `impl Into` arguments up-front.
        let strip_accents = strip_accents.into();
        let add_prefix_space = add_prefix_space.into();
        TokenClassificationConfig {
            model_type,
            model_resource,
            config_resource,
            vocab_resource,
            merges_resource,
            lower_case,
            strip_accents,
            add_prefix_space,
            device: Device::cuda_if_available(),
            label_aggregation_function,
            batch_size: 64,
        }
    }
}
impl Default for TokenClassificationConfig {
fn default() -> TokenClassificationConfig {
TokenClassificationConfig {
model_type: ModelType::Bert,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
BertModelResources::BERT_NER,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_NER,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
BertVocabResources::BERT_NER,
)),
merges_resource: None,
lower_case: false,
strip_accents: None,
add_prefix_space: None,
device: Device::cuda_if_available(),
label_aggregation_function: LabelAggregationOption::First,
batch_size: 64,
}
}
}
/// # Model variants supported by the token classification pipeline
/// Each variant wraps the concrete `...ForTokenClassification` model of one architecture.
pub enum TokenClassificationOption {
/// Bert for token classification
Bert(BertForTokenClassification),
/// DistilBert for token classification
DistilBert(DistilBertForTokenClassification),
/// MobileBert for token classification
MobileBert(MobileBertForTokenClassification),
/// Roberta for token classification
Roberta(RobertaForTokenClassification),
/// XLMRoberta for token classification (shares the Roberta implementation)
XLMRoberta(RobertaForTokenClassification),
/// Electra for token classification
Electra(ElectraForTokenClassification),
/// Albert for token classification
Albert(AlbertForTokenClassification),
/// XLNet for token classification
XLNet(XLNetForTokenClassification),
/// Longformer for token classification
Longformer(LongformerForTokenClassification),
/// FNet for token classification
FNet(FNetForTokenClassification),
}
impl TokenClassificationOption {
    /// Instantiate the token classification model matching `model_type`.
    ///
    /// # Arguments
    /// * `model_type` - architecture to build
    /// * `p` - variable store path where the model weights will be created
    /// * `config` - model configuration; its variant must match `model_type`
    ///
    /// # Errors
    /// Returns `RustBertError::InvalidConfigurationError` if the configuration
    /// variant does not match the requested model type, or if the architecture
    /// does not support token classification.
    pub fn new<'p, P>(
        model_type: ModelType,
        p: P,
        config: &ConfigOption,
    ) -> Result<Self, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        match model_type {
            ModelType::Bert => {
                if let ConfigOption::Bert(config) = config {
                    Ok(TokenClassificationOption::Bert(
                        BertForTokenClassification::new(p, config),
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a BertConfig for Bert!".to_string(),
                    ))
                }
            }
            ModelType::DistilBert => {
                if let ConfigOption::DistilBert(config) = config {
                    Ok(TokenClassificationOption::DistilBert(
                        DistilBertForTokenClassification::new(p, config),
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a DistilBertConfig for DistilBert!".to_string(),
                    ))
                }
            }
            ModelType::MobileBert => {
                if let ConfigOption::MobileBert(config) = config {
                    Ok(TokenClassificationOption::MobileBert(
                        MobileBertForTokenClassification::new(p, config),
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a MobileBertConfig for MobileBert!".to_string(),
                    ))
                }
            }
            // Roberta and XLMRoberta reuse the Bert configuration type.
            ModelType::Roberta => {
                if let ConfigOption::Bert(config) = config {
                    Ok(TokenClassificationOption::Roberta(
                        RobertaForTokenClassification::new(p, config),
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a BertConfig for Roberta!".to_string(),
                    ))
                }
            }
            ModelType::XLMRoberta => {
                if let ConfigOption::Bert(config) = config {
                    Ok(TokenClassificationOption::XLMRoberta(
                        RobertaForTokenClassification::new(p, config),
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a BertConfig for XLMRoberta!".to_string(),
                    ))
                }
            }
            ModelType::Electra => {
                if let ConfigOption::Electra(config) = config {
                    Ok(TokenClassificationOption::Electra(
                        ElectraForTokenClassification::new(p, config),
                    ))
                } else {
                    // Fixed copy-pasted message (previously referenced Bert/Roberta).
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply an ElectraConfig for Electra!".to_string(),
                    ))
                }
            }
            ModelType::Albert => {
                if let ConfigOption::Albert(config) = config {
                    Ok(TokenClassificationOption::Albert(
                        AlbertForTokenClassification::new(p, config),
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply an AlbertConfig for Albert!".to_string(),
                    ))
                }
            }
            ModelType::XLNet => {
                if let ConfigOption::XLNet(config) = config {
                    // Propagate the construction error instead of unwrapping:
                    // the enclosing function already returns RustBertError.
                    Ok(TokenClassificationOption::XLNet(
                        XLNetForTokenClassification::new(p, config)?,
                    ))
                } else {
                    // Fixed copy-pasted message (previously referenced Albert).
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply an XLNetConfig for XLNet!".to_string(),
                    ))
                }
            }
            ModelType::Longformer => {
                if let ConfigOption::Longformer(config) = config {
                    Ok(TokenClassificationOption::Longformer(
                        LongformerForTokenClassification::new(p, config),
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a LongformerConfig for Longformer!".to_string(),
                    ))
                }
            }
            ModelType::FNet => {
                if let ConfigOption::FNet(config) = config {
                    Ok(TokenClassificationOption::FNet(
                        FNetForTokenClassification::new(p, config),
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply an FNetConfig for FNet!".to_string(),
                    ))
                }
            }
            _ => Err(RustBertError::InvalidConfigurationError(format!(
                "Token classification not implemented for {:?}!",
                model_type
            ))),
        }
    }

    /// Returns the `ModelType` of the wrapped model.
    pub fn model_type(&self) -> ModelType {
        match *self {
            Self::Bert(_) => ModelType::Bert,
            Self::Roberta(_) => ModelType::Roberta,
            Self::XLMRoberta(_) => ModelType::XLMRoberta,
            Self::DistilBert(_) => ModelType::DistilBert,
            Self::MobileBert(_) => ModelType::MobileBert,
            Self::Electra(_) => ModelType::Electra,
            Self::Albert(_) => ModelType::Albert,
            Self::XLNet(_) => ModelType::XLNet,
            Self::Longformer(_) => ModelType::Longformer,
            Self::FNet(_) => ModelType::FNet,
        }
    }

    /// Forward pass through the wrapped model, returning the classification
    /// logits. Arguments not supported by a given architecture are ignored
    /// (each arm forwards only what the underlying model accepts).
    fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Tensor {
        match *self {
            Self::Bert(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .logits
            }
            Self::DistilBert(ref model) => {
                model
                    .forward_t(input_ids, mask, input_embeds, train)
                    .expect("Error in distilbert forward_t")
                    .logits
            }
            Self::MobileBert(ref model) => {
                model
                    .forward_t(input_ids, None, None, input_embeds, mask, train)
                    .expect("Error in mobilebert forward_t")
                    .logits
            }
            Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .logits
            }
            Self::Electra(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .logits
            }
            Self::Albert(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .logits
            }
            Self::XLNet(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        None,
                        None,
                        None,
                        token_type_ids,
                        input_embeds,
                        train,
                    )
                    .logits
            }
            Self::Longformer(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        None,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .expect("Error in longformer forward_t")
                    .logits
            }
            Self::FNet(ref model) => {
                model
                    .forward_t(input_ids, token_type_ids, position_ids, input_embeds, train)
                    .expect("Error in fnet forward_t")
                    .logits
            }
        }
    }
}
/// # Token classification pipeline
/// Holds the tokenizer, the wrapped model, the label mapping and the pipeline options.
pub struct TokenClassificationModel {
/// Tokenizer matching the model architecture
tokenizer: TokenizerOption,
/// Wrapped token classification model
token_sequence_classifier: TokenClassificationOption,
/// Mapping from label index to label string, read from the model configuration
label_mapping: HashMap<i64, String>,
/// Variable store holding the model weights (also determines the device)
var_store: VarStore,
/// Strategy used to merge sub-token labels during consolidation
label_aggregation_function: LabelAggregationOption,
/// Maximum model input length (`usize::MAX` if the model configuration declares none)
max_length: usize,
/// Number of features per forward pass
batch_size: usize,
}
impl TokenClassificationModel {
    /// Build a new `TokenClassificationModel` from a configuration.
    ///
    /// Resolves the configured resources locally, builds the tokenizer, reads
    /// the model configuration (label mapping, maximum length) and loads the
    /// weights into a variable store on the configured device.
    ///
    /// # Errors
    /// Returns a `RustBertError` if a resource cannot be resolved, the
    /// tokenizer cannot be created or the weights fail to load.
    pub fn new(
        config: TokenClassificationConfig,
    ) -> Result<TokenClassificationModel, RustBertError> {
        let config_path = config.config_resource.get_local_path()?;
        let vocab_path = config.vocab_resource.get_local_path()?;
        let weights_path = config.model_resource.get_local_path()?;
        // Merges are only present for BPE-based tokenizers.
        let merges_path = config
            .merges_resource
            .as_ref()
            .map(|merges_resource| merges_resource.get_local_path())
            .transpose()?;
        let device = config.device;
        let label_aggregation_function = config.label_aggregation_function;
        let tokenizer = TokenizerOption::from_file(
            config.model_type,
            vocab_path.to_str().unwrap(),
            merges_path.as_deref().map(|path| path.to_str().unwrap()),
            config.lower_case,
            config.strip_accents,
            config.add_prefix_space,
        )?;
        let mut var_store = VarStore::new(device);
        let model_config = ConfigOption::from_file(config.model_type, config_path);
        // Models without a declared maximum length are treated as unbounded.
        let max_length = model_config
            .get_max_len()
            .map(|v| v as usize)
            .unwrap_or(usize::MAX);
        let token_sequence_classifier =
            TokenClassificationOption::new(config.model_type, &var_store.root(), &model_config)?;
        let label_mapping = model_config.get_label_mapping().clone();
        let batch_size = config.batch_size;
        var_store.load(weights_path)?;
        Ok(TokenClassificationModel {
            tokenizer,
            token_sequence_classifier,
            label_mapping,
            var_store,
            label_aggregation_function,
            max_length,
            batch_size,
        })
    }

    /// Tokenize one input example and split it into model input spans no longer
    /// than the model maximum length.
    ///
    /// Long inputs yield several overlapping spans (stride `max_length / 4`);
    /// the `reference_feature` flags mark which span is responsible for the
    /// prediction at each position, so overlapping regions are emitted once.
    fn generate_features<S>(&self, input: S, example_index: usize) -> Vec<InputFeature>
    where
        S: AsRef<str>,
    {
        let tokenized_input = self.tokenizer.tokenize_with_offsets(input.as_ref());
        let encoded_input = TokenIdsWithOffsets {
            ids: self
                .tokenizer
                .convert_tokens_to_ids(&tokenized_input.tokens),
            offsets: tokenized_input.offsets,
            reference_offsets: tokenized_input.reference_offsets,
            masks: tokenized_input.masks,
        };
        // Number of special tokens added around a single sequence, obtained by
        // encoding an empty input.
        let sequence_added_tokens = self
            .tokenizer
            .build_input_with_special_tokens(
                TokenIdsWithOffsets {
                    ids: vec![],
                    offsets: vec![],
                    reference_offsets: vec![],
                    masks: vec![],
                },
                None,
            )
            .token_ids
            .len();
        let max_content_length = self.max_length - sequence_added_tokens;
        let doc_stride = self.max_length / 4;
        let mut spans: Vec<InputFeature> = vec![];
        let mut start_token = 0_usize;
        let total_length = encoded_input.ids.len();
        // Each iteration advances by `doc_stride` content tokens; the loop
        // breaks as soon as a span reaches the end of the sequence.
        while (spans.len() * doc_stride) < total_length {
            let end_token = min(start_token + max_content_length, total_length);
            let sub_encoded_input = TokenIdsWithOffsets {
                ids: encoded_input.ids[start_token..end_token].to_vec(),
                offsets: encoded_input.offsets[start_token..end_token].to_vec(),
                reference_offsets: encoded_input.reference_offsets[start_token..end_token].to_vec(),
                masks: encoded_input.masks[start_token..end_token].to_vec(),
            };
            let encoded_span = self
                .tokenizer
                .build_input_with_special_tokens(sub_encoded_input, None);
            let reference_feature = self.get_reference_feature_flag(
                start_token,
                end_token,
                total_length,
                doc_stride,
                &encoded_span,
            );
            let feature = InputFeature {
                input_ids: encoded_span.token_ids,
                offsets: encoded_span.token_offsets,
                mask: encoded_span.mask,
                reference_feature,
                example_index,
            };
            spans.push(feature);
            if end_token == total_length {
                break;
            }
            // Overlap successive spans by `doc_stride` tokens.
            start_token = end_token - doc_stride;
        }
        spans
    }

    /// Compute, for every position of an encoded span, whether this span is the
    /// reference span for that position.
    ///
    /// Spans that are not at the start (resp. end) of the sequence delegate
    /// their first (resp. last) `doc_stride / 2` content tokens — plus the
    /// surrounding special tokens — to the neighbouring overlapping span.
    fn get_reference_feature_flag(
        &self,
        start_token: usize,
        end_token: usize,
        total_length: usize,
        doc_stride: usize,
        encoded_span: &TokenizedInput,
    ) -> Vec<bool> {
        // Positions before this cutoff belong to the previous span.
        let start_cutoff = if start_token > 0 {
            let leading_special_tokens = {
                let mut counter = 0;
                let mut masks = encoded_span.mask.iter();
                while masks.next().unwrap_or(&Mask::None) == &Mask::Special {
                    counter += 1;
                }
                counter
            };
            doc_stride / 2 + leading_special_tokens
        } else {
            0
        };
        // Positions at or after this cutoff belong to the next span.
        let end_cutoff = if end_token < total_length {
            let trailing_special_tokens = {
                let mut counter = 0;
                let mut masks = encoded_span.mask.iter().rev();
                while masks.next().unwrap_or(&Mask::None) == &Mask::Special {
                    counter += 1;
                }
                counter
            };
            encoded_span.token_ids.len() - doc_stride / 2 - trailing_special_tokens
        } else {
            encoded_span.token_ids.len()
        };
        let mut reference_feature = vec![true; encoded_span.token_ids.len()];
        reference_feature[..start_cutoff]
            .iter_mut()
            .for_each(|v| *v = false);
        reference_feature[end_cutoff..]
            .iter_mut()
            .for_each(|v| *v = false);
        reference_feature
    }

    /// Classify the tokens of one or more input texts.
    ///
    /// # Arguments
    /// * `input` - texts to classify
    /// * `consolidate_sub_tokens` - if `true`, sub-tokens belonging to the same
    ///   word are merged into a single `Token` using the configured label
    ///   aggregation strategy
    /// * `return_special` - if `true`, special tokens are included in the output
    ///
    /// Returns one `Vec<Token>` per input, in input order.
    pub fn predict<S>(
        &self,
        input: &[S],
        consolidate_sub_tokens: bool,
        return_special: bool,
    ) -> Vec<Vec<Token>>
    where
        S: AsRef<str>,
    {
        let mut features: Vec<InputFeature> = input
            .iter()
            .enumerate()
            .flat_map(|(example_index, example)| self.generate_features(example, example_index))
            .collect();
        let mut example_tokens_map: HashMap<usize, Vec<Token>> = HashMap::new();
        for example_idx in 0..input.len() {
            example_tokens_map.insert(example_idx, Vec::new());
        }
        let mut start = 0usize;
        let len_features = features.len();
        // Process the features in batches of at most `batch_size`.
        while start < len_features {
            let end = start + min(len_features - start, self.batch_size);
            no_grad(|| {
                let batch_features = &mut features[start..end];
                let (input_ids, attention_masks) = self.pad_features(batch_features);
                let output = self.token_sequence_classifier.forward_t(
                    Some(&input_ids),
                    Some(&attention_masks),
                    None,
                    None,
                    None,
                    false,
                );
                // Softmax over the label dimension; `exp` is computed once and reused.
                let exp_output = output.exp();
                let score = &exp_output / exp_output.sum_dim_intlist(&[-1], true, Kind::Float);
                let label_indices = score.argmax(-1, true);
                for sentence_idx in 0..label_indices.size()[0] {
                    let labels = label_indices.get(sentence_idx);
                    // `sentence_idx` is relative to the current batch: it must be
                    // offset by `start` to address the matching feature. The
                    // previous implementation indexed `features[sentence_idx]`,
                    // which used the wrong feature for every batch but the first.
                    let feature = &features[start + sentence_idx as usize];
                    let sentence_reference_flag = &feature.reference_feature;
                    let original_chars = input[feature.example_index]
                        .as_ref()
                        .chars()
                        .collect::<Vec<char>>();
                    let mut word_idx: u16 = 0;
                    // Only emit tokens for positions this span is responsible for.
                    for position_idx in sentence_reference_flag
                        .iter()
                        .enumerate()
                        .filter(|(_, flag)| **flag)
                        .map(|(pos, _)| pos)
                    {
                        let mask = feature.mask[position_idx];
                        if (mask == Mask::Special) & (!return_special) {
                            continue;
                        }
                        // Any non-continuation token starts a new word.
                        if !(mask == Mask::Continuation) {
                            word_idx += 1;
                        }
                        let token = self.decode_token(
                            &original_chars,
                            feature,
                            &input_ids,
                            &labels,
                            &score,
                            sentence_idx,
                            position_idx as i64,
                            word_idx,
                        );
                        example_tokens_map
                            .get_mut(&(feature.example_index))
                            .unwrap()
                            .push(token);
                    }
                }
            });
            start = end;
        }
        // Restore the original example order (HashMap iteration is unordered).
        let mut tokens = example_tokens_map
            .into_iter()
            .collect::<Vec<(usize, Vec<Token>)>>();
        tokens.sort_by_key(|kv| kv.0);
        let mut tokens = tokens
            .into_iter()
            .map(|(_, v)| v)
            .collect::<Vec<Vec<Token>>>();
        if consolidate_sub_tokens {
            self.consolidate_tokens(&mut tokens, &self.label_aggregation_function);
        }
        tokens
    }

    /// Pad all features of a batch to the longest feature length and build the
    /// input id and attention mask tensors on the model device.
    ///
    /// `offsets` and `reference_feature` are padded in step with `input_ids`
    /// so that positional indexing stays consistent across the fields.
    fn pad_features(&self, features: &mut [InputFeature]) -> (Tensor, Tensor) {
        let max_len = features
            .iter()
            .map(|feature| feature.input_ids.len())
            .max()
            .unwrap();
        // Attention mask: 1 on real tokens, 0 on padding.
        let attention_masks = features
            .iter()
            .map(|feature| {
                let mut attention_mask = vec![1; feature.input_ids.len()];
                attention_mask.append(&mut vec![0; max_len - attention_mask.len()]);
                Tensor::of_slice(&attention_mask)
            })
            .collect::<Vec<_>>();
        for feature in features.iter_mut() {
            // Compute the padding length once, before mutating any field. The
            // previous implementation derived it from `input_ids.len()` *after*
            // `input_ids` had been padded, so `reference_feature` was never
            // actually extended (the appended vector was always empty).
            let pad_length = max_len - feature.input_ids.len();
            feature.offsets.append(&mut vec![None; pad_length]);
            feature.input_ids.append(&mut vec![
                self.tokenizer.get_pad_id().expect(
                    "Only tokenizers with a padding index can be used for token classification"
                );
                pad_length
            ]);
            feature.reference_feature.append(&mut vec![false; pad_length]);
        }
        let padded_input_ids = features
            .iter()
            .map(|input| Tensor::of_slice(input.input_ids.as_slice()))
            .collect::<Vec<_>>();
        let input_ids = Tensor::stack(&padded_input_ids, 0).to(self.var_store.device());
        let attention_masks = Tensor::stack(&attention_masks, 0).to(self.var_store.device());
        (input_ids, attention_masks)
    }

    /// Build a `Token` from the model outputs at a given (sentence, position).
    ///
    /// When the tokenizer provides character offsets, the text is sliced from
    /// the original input (the end offset is clamped to the input length);
    /// otherwise the token id is decoded back through the tokenizer.
    fn decode_token(
        &self,
        original_sentence_chars: &[char],
        sentence_tokens: &InputFeature,
        input_tensor: &Tensor,
        labels: &Tensor,
        score: &Tensor,
        sentence_idx: i64,
        position_idx: i64,
        word_index: u16,
    ) -> Token {
        let label_id = labels.int64_value(&[position_idx]);
        let token_id = input_tensor.int64_value(&[sentence_idx, position_idx]);
        let offsets = &sentence_tokens.offsets[position_idx as usize];
        let text = match offsets {
            None => match self.tokenizer {
                TokenizerOption::Bert(ref tokenizer) => {
                    Tokenizer::decode(tokenizer, &[token_id], false, false)
                }
                TokenizerOption::Roberta(ref tokenizer) => {
                    Tokenizer::decode(tokenizer, &[token_id], false, false)
                }
                TokenizerOption::XLMRoberta(ref tokenizer) => {
                    Tokenizer::decode(tokenizer, &[token_id], false, false)
                }
                TokenizerOption::Albert(ref tokenizer) => {
                    Tokenizer::decode(tokenizer, &[token_id], false, false)
                }
                TokenizerOption::XLNet(ref tokenizer) => {
                    Tokenizer::decode(tokenizer, &[token_id], false, false)
                }
                _ => panic!(
                    "Token classification not implemented for {:?}!",
                    self.tokenizer.model_type()
                ),
            },
            Some(offsets) => {
                let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize);
                let end_char = min(end_char, original_sentence_chars.len());
                original_sentence_chars[start_char..end_char].iter().collect()
            }
        };
        Token {
            text,
            score: score.double_value(&[sentence_idx, position_idx, label_id]),
            label: self
                .label_mapping
                .get(&label_id)
                .expect("Index out of vocabulary bounds.")
                .to_owned(),
            label_index: label_id,
            sentence: sentence_idx as usize,
            index: position_idx as u16,
            word_index,
            offset: offsets.to_owned(),
            mask: sentence_tokens.mask[position_idx as usize],
        }
    }

    /// Merge consecutive sub-tokens belonging to the same word into single
    /// tokens.
    ///
    /// The consolidated label is chosen by `label_aggregation_function`; the
    /// consolidated score is the product of the sub-token probabilities for the
    /// chosen label (using `1 - score` for sub-tokens that disagree with it).
    fn consolidate_tokens(
        &self,
        tokens: &mut Vec<Vec<Token>>,
        label_aggregation_function: &LabelAggregationOption,
    ) {
        for sequence_tokens in tokens {
            let mut tokens_to_replace = vec![];
            let token_iter = sequence_tokens.iter_consolidate_tokens();
            let mut cursor = 0;
            for sub_tokens in token_iter {
                if sub_tokens.len() > 1 {
                    let (label_index, label) =
                        self.consolidate_labels(sub_tokens, label_aggregation_function);
                    let sentence = (sub_tokens[0]).sentence;
                    let index = (sub_tokens[0]).index;
                    let word_index = (sub_tokens[0]).word_index;
                    let offset_start = sub_tokens
                        .first()
                        .unwrap()
                        .offset
                        .as_ref()
                        .map(|offset| offset.begin);
                    let offset_end = sub_tokens
                        .last()
                        .unwrap()
                        .offset
                        .as_ref()
                        .map(|offset| offset.end);
                    // A consolidated offset only exists if both ends are known.
                    let offset = if let (Some(offset_start), Some(offset_end)) =
                        (offset_start, offset_end)
                    {
                        Some(Offset::new(offset_start, offset_end))
                    } else {
                        None
                    };
                    let mut text = String::new();
                    let mut score = 1f64;
                    for current_sub_token in sub_tokens.iter() {
                        text.push_str(current_sub_token.text.as_str());
                        score *= if current_sub_token.label_index == label_index {
                            current_sub_token.score
                        } else {
                            1.0 - current_sub_token.score
                        };
                    }
                    let token = Token {
                        text,
                        score,
                        label,
                        label_index,
                        sentence,
                        index,
                        word_index,
                        offset,
                        mask: Default::default(),
                    };
                    tokens_to_replace.push(((cursor, cursor + sub_tokens.len()), token));
                }
                cursor += sub_tokens.len();
            }
            // Replace from the back so earlier ranges remain valid; `once`
            // moves the token instead of cloning it.
            for ((start, end), token) in tokens_to_replace.into_iter().rev() {
                sequence_tokens.splice(start..end, std::iter::once(token));
            }
        }
    }

    /// Pick the (label index, label) pair for a group of sub-tokens according
    /// to the configured aggregation strategy.
    fn consolidate_labels(
        &self,
        tokens: &[Token],
        aggregation: &LabelAggregationOption,
    ) -> (i64, String) {
        match aggregation {
            LabelAggregationOption::First => {
                let token = tokens.first().unwrap();
                (token.label_index, token.label.clone())
            }
            LabelAggregationOption::Last => {
                let token = tokens.last().unwrap();
                (token.label_index, token.label.clone())
            }
            LabelAggregationOption::Mode => {
                let counts = tokens.iter().fold(HashMap::new(), |mut m, c| {
                    *m.entry((c.label_index, c.label.as_str())).or_insert(0) += 1;
                    m
                });
                // NOTE(review): ties between equally frequent labels are broken
                // by HashMap iteration order, which is not deterministic.
                counts
                    .into_iter()
                    .max_by(|a, b| a.1.cmp(&b.1))
                    .map(|((label_index, label), _)| (label_index, label.to_owned()))
                    .unwrap()
            }
            LabelAggregationOption::Custom(function) => function(tokens),
        }
    }
}