use crate::bert::BertForSequenceClassification;
use crate::common::resources::{download_resource, RemoteResource, Resource};
use crate::distilbert::{
DistilBertConfigResources, DistilBertModelClassifier, DistilBertModelResources,
DistilBertVocabResources,
};
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::roberta::RobertaForSequenceClassification;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{
TokenizedInput, TruncationStrategy,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tch::nn::VarStore;
use tch::{no_grad, Device, Kind, Tensor};
/// A classification label produced by a [`SequenceClassificationModel`] for one input sequence.
#[derive(Debug, Serialize, Deserialize)]
pub struct Label {
    /// Human-readable class name, resolved through the model's label mapping
    pub text: String,
    /// Confidence score: softmax probability in `predict`, sigmoid activation in `predict_multilabel`
    pub score: f64,
    /// Numeric class identifier (key into the model's label mapping)
    pub id: i64,
    /// Zero-based index of the input sentence this label belongs to
    #[serde(default)]
    pub sentence: usize,
}
/// Configuration for a sequence classification pipeline: which architecture to build,
/// where its resources live, and which device to run on.
pub struct SequenceClassificationConfig {
    /// Model architecture to instantiate (Bert, DistilBert or Roberta)
    pub model_type: ModelType,
    /// Resource pointing to the model weights
    pub model_resource: Resource,
    /// Resource pointing to the model configuration file
    pub config_resource: Resource,
    /// Resource pointing to the vocabulary file
    pub vocab_resource: Resource,
    /// Optional resource pointing to a merges file (`None` for tokenizers that do not use one)
    pub merges_resource: Option<Resource>,
    /// Whether the tokenizer lower-cases its input
    pub lower_case: bool,
    /// Device to place the model on
    pub device: Device,
}
impl SequenceClassificationConfig {
    /// Builds a new configuration from the given model type and resources.
    ///
    /// The device is not a parameter: it is always set to CUDA when available,
    /// falling back to CPU otherwise.
    pub fn new(
        model_type: ModelType,
        model_resource: Resource,
        config_resource: Resource,
        vocab_resource: Resource,
        merges_resource: Option<Resource>,
        lower_case: bool,
    ) -> SequenceClassificationConfig {
        Self {
            model_type,
            model_resource,
            config_resource,
            vocab_resource,
            merges_resource,
            lower_case,
            device: Device::cuda_if_available(),
        }
    }
}
impl Default for SequenceClassificationConfig {
    /// Default pipeline: DistilBERT fine-tuned on SST-2 (binary sentiment),
    /// with all resources fetched from the pretrained remote registry.
    fn default() -> SequenceClassificationConfig {
        // Resolve the three remote resources for the DistilBERT SST-2 checkpoint.
        let model_resource = Resource::Remote(RemoteResource::from_pretrained(
            DistilBertModelResources::DISTIL_BERT_SST2,
        ));
        let config_resource = Resource::Remote(RemoteResource::from_pretrained(
            DistilBertConfigResources::DISTIL_BERT_SST2,
        ));
        let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
            DistilBertVocabResources::DISTIL_BERT_SST2,
        ));
        SequenceClassificationConfig {
            model_type: ModelType::DistilBert,
            model_resource,
            config_resource,
            vocab_resource,
            // DistilBERT's WordPiece tokenizer needs no merges file.
            merges_resource: None,
            lower_case: true,
            device: Device::cuda_if_available(),
        }
    }
}
/// Dispatch enum holding the concrete model backing a sequence classification pipeline.
pub enum SequenceClassificationOption {
    /// BERT-based classifier
    Bert(BertForSequenceClassification),
    /// DistilBERT-based classifier
    DistilBert(DistilBertModelClassifier),
    /// RoBERTa-based classifier (configured via a BERT-style config)
    Roberta(RobertaForSequenceClassification),
}
impl SequenceClassificationOption {
    /// Instantiates the model variant matching `model_type` under the given
    /// variable-store path, using the supplied configuration.
    ///
    /// # Panics
    /// Panics when the configuration variant does not match the model type,
    /// or when `model_type` is Electra (no sequence classification head).
    pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self {
        // Matching on the (model type, config) pair makes each valid pairing
        // and each mismatch an explicit arm.
        match (model_type, config) {
            (ModelType::Bert, ConfigOption::Bert(config)) => {
                Self::Bert(BertForSequenceClassification::new(p, config))
            }
            (ModelType::Bert, _) => panic!("You can only supply a BertConfig for Bert!"),
            (ModelType::DistilBert, ConfigOption::DistilBert(config)) => {
                Self::DistilBert(DistilBertModelClassifier::new(p, config))
            }
            (ModelType::DistilBert, _) => {
                panic!("You can only supply a DistilBertConfig for DistilBert!")
            }
            // RoBERTa reuses the BERT configuration format.
            (ModelType::Roberta, ConfigOption::Bert(config)) => {
                Self::Roberta(RobertaForSequenceClassification::new(p, config))
            }
            (ModelType::Roberta, _) => panic!("You can only supply a BertConfig for Roberta!"),
            (ModelType::Electra, _) => panic!("SequenceClassification not implemented for Electra!"),
        }
    }

    /// Returns the `ModelType` corresponding to the wrapped model.
    pub fn model_type(&self) -> ModelType {
        match self {
            Self::Bert(_) => ModelType::Bert,
            Self::DistilBert(_) => ModelType::DistilBert,
            Self::Roberta(_) => ModelType::Roberta,
        }
    }

    /// Forwards the inputs through the wrapped model.
    ///
    /// Returns `(logits, all_hidden_states, all_attentions)`. Note that the
    /// DistilBERT variant ignores `token_type_ids` and `position_ids`, since its
    /// forward pass does not accept them.
    pub fn forward_t(
        &self,
        input_ids: Option<Tensor>,
        mask: Option<Tensor>,
        token_type_ids: Option<Tensor>,
        position_ids: Option<Tensor>,
        input_embeds: Option<Tensor>,
        train: bool,
    ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        match self {
            Self::Bert(model) => model.forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            ),
            Self::DistilBert(model) => model
                .forward_t(input_ids, mask, input_embeds, train)
                .expect("Error in distilbert forward_t"),
            Self::Roberta(model) => model.forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            ),
        }
    }
}
/// Ready-to-use pipeline for sequence-level classification.
pub struct SequenceClassificationModel {
    // Tokenizer matching the selected model type
    tokenizer: TokenizerOption,
    // Underlying classifier (BERT / DistilBERT / RoBERTa)
    sequence_classifier: SequenceClassificationOption,
    // Maps numeric class ids to human-readable label names (read from the model config)
    label_mapping: HashMap<i64, String>,
    // Owns the model weights; also records the device inputs must be moved to
    var_store: VarStore,
}
impl SequenceClassificationModel {
pub fn new(
config: SequenceClassificationConfig,
) -> failure::Fallible<SequenceClassificationModel> {
let config_path = download_resource(&config.config_resource)?;
let vocab_path = download_resource(&config.vocab_resource)?;
let weights_path = download_resource(&config.model_resource)?;
let merges_path = if let Some(merges_resource) = &config.merges_resource {
Some(download_resource(merges_resource).expect("Failure downloading resource"))
} else {
None
};
let device = config.device;
let tokenizer = TokenizerOption::from_file(
config.model_type,
vocab_path.to_str().unwrap(),
merges_path.map(|path| path.to_str().unwrap()),
config.lower_case,
);
let mut var_store = VarStore::new(device);
let model_config = ConfigOption::from_file(config.model_type, config_path);
let sequence_classifier =
SequenceClassificationOption::new(config.model_type, &var_store.root(), &model_config);
let label_mapping = model_config.get_label_mapping();
var_store.load(weights_path)?;
Ok(SequenceClassificationModel {
tokenizer,
sequence_classifier,
label_mapping,
var_store,
})
}
fn prepare_for_model(&self, input: Vec<&str>) -> Tensor {
let tokenized_input: Vec<TokenizedInput> =
self.tokenizer
.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device())
}
pub fn predict(&self, input: &[&str]) -> Vec<Label> {
let input_tensor = self.prepare_for_model(input.to_vec());
let output = no_grad(|| {
let (output, _, _) = self.sequence_classifier.forward_t(
Some(input_tensor.copy()),
None,
None,
None,
None,
false,
);
output.softmax(-1, Kind::Float).detach().to(Device::Cpu)
});
let label_indices = output.as_ref().argmax(-1, true).squeeze1(1);
let scores = output
.gather(1, &label_indices.unsqueeze(-1), false)
.squeeze1(1);
let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();
let mut labels: Vec<Label> = vec![];
for sentence_idx in 0..label_indices.len() {
let label_string = self
.label_mapping
.get(&label_indices[sentence_idx])
.unwrap()
.clone();
let label = Label {
text: label_string,
score: scores[sentence_idx],
id: label_indices[sentence_idx],
sentence: sentence_idx,
};
labels.push(label)
}
labels
}
pub fn predict_multilabel(&self, input: &[&str], threshold: f64) -> Vec<Vec<Label>> {
let input_tensor = self.prepare_for_model(input.to_vec());
let output = no_grad(|| {
let (output, _, _) = self.sequence_classifier.forward_t(
Some(input_tensor.copy()),
None,
None,
None,
None,
false,
);
output.sigmoid().detach().to(Device::Cpu)
});
let label_indices = output.as_ref().ge(threshold).nonzero();
let mut labels: Vec<Vec<Label>> = vec![];
let mut sequence_labels: Vec<Label> = vec![];
for sentence_idx in 0..label_indices.size()[0] {
let label_index_tensor = label_indices.get(sentence_idx);
let sentence_label = label_index_tensor
.iter::<i64>()
.unwrap()
.collect::<Vec<i64>>();
let (sentence, id) = (sentence_label[0], sentence_label[1]);
if sentence as usize > labels.len() {
labels.push(sequence_labels);
sequence_labels = vec![];
}
let score = output.double_value(sentence_label.as_slice());
let label_string = self.label_mapping.get(&id).unwrap().to_owned();
let label = Label {
text: label_string,
score,
id,
sentence: sentence as usize,
};
sequence_labels.push(label);
}
if sequence_labels.len() > 0 {
labels.push(sequence_labels);
}
labels
}
}