use rust_tokenizers::bert_tokenizer::BertTokenizer;
use tch::nn::VarStore;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{TruncationStrategy, MultiThreadedTokenizer};
use std::collections::HashMap;
use tch::{Tensor, no_grad, Device};
use tch::kind::Kind::Float;
use crate::bert::{BertForTokenClassification, BertConfig, BertModelResources, BertConfigResources, BertVocabResources};
use crate::Config;
use crate::common::resources::{Resource, RemoteResource, download_resource};
/// A single token-level prediction returned by `NERModel::predict`.
///
/// One `Entity` is produced for every token position whose predicted label id is non-zero.
#[derive(Debug, Clone, PartialEq)]
pub struct Entity {
    /// Text obtained by decoding the single token id found at the predicted position.
    pub word: String,
    /// Normalized score (probability) the model assigned to the predicted label at that position.
    pub score: f64,
    /// Label name looked up from the model's `id2label` mapping.
    pub label: String,
}
/// Settings consumed by `NERModel::new` to locate the model files and choose a device.
pub struct NERConfig {
    /// Resource resolved to the weights file that is loaded into the variable store.
    pub model_resource: Resource,
    /// Resource resolved to the file the `BertConfig` is read from.
    pub config_resource: Resource,
    /// Resource resolved to the vocabulary file used to build the `BertTokenizer`.
    pub vocab_resource: Resource,
    /// Device the variable store (and therefore the model) is created on.
    pub device: Device,
}
impl Default for NERConfig {
    /// Builds a configuration pointing at the remote pretrained `BERT_NER` model, config and
    /// vocabulary resources, using the device chosen by `Device::cuda_if_available()`.
    fn default() -> Self {
        let model_resource =
            Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER));
        let config_resource =
            Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER));
        let vocab_resource =
            Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER));
        let device = Device::cuda_if_available();

        Self {
            model_resource,
            config_resource,
            vocab_resource,
            device,
        }
    }
}
/// Named-entity recognition pipeline: a BERT tokenizer plus a BERT token-classification head.
pub struct NERModel {
    /// Tokenizer used both to encode the input sentences and to decode predicted token ids.
    tokenizer: BertTokenizer,
    /// Token-classification model run over the encoded batch
    /// (the field name says "sequence", the type is `BertForTokenClassification`).
    bert_sequence_classifier: BertForTokenClassification,
    /// Maps a predicted label id to its label name, taken from the configuration's `id2label`.
    label_mapping: HashMap<i64, String>,
    /// Variable store holding the loaded weights; also the source of the target device.
    var_store: VarStore,
}
impl NERModel {
/// Builds a ready-to-use `NERModel` from the given configuration.
///
/// Resolves the configuration, vocabulary and weights resources through `download_resource`,
/// creates the tokenizer and the token-classification model on `ner_config.device`, and loads
/// the weights into the variable store.
///
/// # Errors
///
/// Returns an error when:
/// * any of the three resources cannot be resolved,
/// * the vocabulary path is not valid UTF-8,
/// * the configuration carries no `id2label` dictionary,
/// * the weights cannot be loaded into the variable store.
pub fn new(ner_config: NERConfig) -> failure::Fallible<NERModel> {
    let config_path = download_resource(&ner_config.config_resource)?;
    let vocab_path = download_resource(&ner_config.vocab_resource)?;
    let weights_path = download_resource(&ner_config.model_resource)?;
    let device = ner_config.device;

    // The tokenizer wants a `&str` path; report a non-UTF-8 path as an error rather than panic.
    let vocab_file = vocab_path
        .to_str()
        .ok_or_else(|| failure::err_msg("Vocabulary path is not valid UTF-8"))?;
    // NOTE(review): the `false` flag presumably disables lower-casing — confirm against the
    // `BertTokenizer::from_file` signature.
    let tokenizer = BertTokenizer::from_file(vocab_file, false);

    let mut var_store = VarStore::new(device);
    let config = BertConfig::from_file(config_path);
    // The model must be built on the store's root before the weights are loaded below.
    let bert_sequence_classifier = BertForTokenClassification::new(&var_store.root(), &config);

    // `predict` cannot name any label without this mapping, so its absence is a hard error.
    let label_mapping = config.id2label.ok_or_else(|| {
        failure::err_msg("No label dictionary (id2label) provided in configuration file")
    })?;

    var_store.load(weights_path)?;

    Ok(NERModel {
        tokenizer,
        bert_sequence_classifier,
        label_mapping,
        var_store,
    })
}
/// Encodes a batch of sentences into a single padded tensor of token ids.
///
/// Every sentence is encoded with the `LongestFirst` truncation strategy (numeric arguments
/// `128` and `0` — presumably maximum length and stride; confirm against the tokenizer API).
/// Sequences shorter than the longest one in the batch are right-padded with id `0`, the rows
/// are stacked along dimension 0, and the result is moved to the variable store's device.
///
/// # Panics
///
/// Panics if `input` is empty, because there is no longest sequence to pad to.
fn prepare_for_model(&self, input: Vec<&str>) -> Tensor {
    // `input` is already owned, so it is handed over directly instead of being cloned.
    let encoded = self.tokenizer.encode_list(
        input,
        128,
        &TruncationStrategy::LongestFirst,
        0,
    );

    let max_len = encoded
        .iter()
        .map(|sequence| sequence.token_ids.len())
        .max()
        .expect("prepare_for_model requires at least one input sequence");

    // Consume the encodings so each id vector is moved and padded in place, not cloned.
    let padded_rows = encoded
        .into_iter()
        .map(|sequence| {
            let mut token_ids = sequence.token_ids;
            token_ids.resize(max_len, 0);
            Tensor::of_slice(&token_ids)
        })
        .collect::<Vec<_>>();

    Tensor::stack(padded_rows.as_slice(), 0).to(self.var_store.device())
}
/// Runs the model over `input` and returns every token whose predicted label id is non-zero.
///
/// The sentences are encoded into one padded batch, passed through the token classifier
/// without gradient tracking, and the logits are normalized over the last (label) dimension.
/// For each position the highest-scoring label is taken; positions predicted as label id `0`
/// are skipped (presumably the "outside"/no-entity class — confirm against `id2label`).
///
/// Returns an empty vector when `input` is empty.
///
/// # Panics
///
/// Panics if a predicted label id is missing from the label mapping.
pub fn predict(&self, input: &[&str]) -> Vec<Entity> {
    // Nothing to encode: bail out before `prepare_for_model`, which panics on an empty batch.
    if input.is_empty() {
        return Vec::new();
    }

    let input_tensor = self.prepare_for_model(input.to_vec());
    // The ids are needed again below for decoding, so the model receives its own copy.
    let (output, _, _) = no_grad(|| {
        self.bert_sequence_classifier.forward_t(
            Some(input_tensor.copy()),
            None,
            None,
            None,
            None,
            false,
        )
    });
    let output = output.detach().to(Device::Cpu);

    // Library softmax instead of exp(x) / sum(exp(x)): same result, but it does not overflow
    // on large logits and evaluates the exponentials only once.
    let score: Tensor = output.softmax(-1, Float);
    let labels_idx = score.argmax(-1, false);

    let dims = labels_idx.size();
    let (batch_size, sequence_length) = (dims[0], dims[1]);

    let mut entities: Vec<Entity> = Vec::new();
    for sentence_idx in 0..batch_size {
        for position_idx in 0..sequence_length {
            let label = labels_idx.int64_value(&[sentence_idx, position_idx]);
            if label == 0 {
                continue;
            }
            let token_id = input_tensor.int64_value(&[sentence_idx, position_idx]);
            entities.push(Entity {
                // The `Tokenizer` trait is not imported in this file, hence the full path.
                word: rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Tokenizer::decode(
                    &self.tokenizer,
                    vec![token_id],
                    true,
                    true,
                ),
                score: score.double_value(&[sentence_idx, position_idx, label]),
                label: self
                    .label_mapping
                    .get(&label)
                    .expect("Index out of vocabulary bounds.")
                    .to_owned(),
            });
        }
    }
    entities
}
}