use crate::albert::AlbertForSequenceClassification;
use crate::bart::BartForSequenceClassification;
use crate::bert::BertForSequenceClassification;
use crate::deberta::DebertaForSequenceClassification;
use crate::deberta_v2::DebertaV2ForSequenceClassification;
use crate::distilbert::DistilBertModelClassifier;
use crate::longformer::LongformerForSequenceClassification;
use crate::mobilebert::MobileBertForSequenceClassification;
use crate::pipelines::common::{
cast_var_store, ConfigOption, ModelResource, ModelType, TokenizerOption,
};
use crate::pipelines::sequence_classification::Label;
use crate::resources::ResourceProvider;
use crate::roberta::RobertaForSequenceClassification;
use crate::xlnet::XLNetForSequenceClassification;
use crate::RustBertError;
use rust_tokenizers::tokenizer::TruncationStrategy;
use rust_tokenizers::TokenizedInput;
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::{config::ONNXEnvironmentConfig, ONNXEncoder};
#[cfg(feature = "remote")]
use crate::{
bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
resources::RemoteResource,
};
use tch::kind::Kind::{Bool, Float};
use tch::nn::VarStore;
use tch::{no_grad, Device, Kind, Tensor};
/// # Configuration for a zero-shot classification model
///
/// Bundles the resources (weights, configuration, vocabulary, optional merges),
/// tokenizer flags and target device needed to build a zero-shot classifier.
pub struct ZeroShotClassificationConfig {
    /// Model architecture used to dispatch loading and inference
    pub model_type: ModelType,
    /// Model weights resource (Torch or ONNX variant)
    pub model_resource: ModelResource,
    /// Model configuration resource
    pub config_resource: Box<dyn ResourceProvider + Send>,
    /// Tokenizer vocabulary resource
    pub vocab_resource: Box<dyn ResourceProvider + Send>,
    /// Optional merges resource (used by BPE-based tokenizers such as BART's)
    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
    /// Whether the tokenizer should lower-case the input
    pub lower_case: bool,
    /// Whether the tokenizer strips accents (tokenizer-specific default when `None`)
    pub strip_accents: Option<bool>,
    /// Whether the tokenizer adds a prefix space (tokenizer-specific default when `None`)
    pub add_prefix_space: Option<bool>,
    /// Device to place the model on
    pub device: Device,
    /// Optional numeric kind to cast the loaded weights to (e.g. half precision)
    pub kind: Option<Kind>,
}
impl ZeroShotClassificationConfig {
    /// Builds a new `ZeroShotClassificationConfig`.
    ///
    /// The device defaults to CUDA when available (CPU otherwise) and `kind`
    /// defaults to `None` (weights keep the kind stored in the checkpoint).
    ///
    /// # Arguments
    ///
    /// * `model_type` - model architecture to load
    /// * `model_resource` - resource pointing to the model weights
    /// * `config_resource` - resource pointing to the model configuration
    /// * `vocab_resource` - resource pointing to the tokenizer vocabulary
    /// * `merges_resource` - optional resource pointing to the tokenizer merges
    /// * `lower_case` - whether the tokenizer lower-cases input
    /// * `strip_accents` - optional accent-stripping flag for the tokenizer
    /// * `add_prefix_space` - optional prefix-space flag for the tokenizer
    pub fn new<RC, RV>(
        model_type: ModelType,
        model_resource: ModelResource,
        config_resource: RC,
        vocab_resource: RV,
        merges_resource: Option<RV>,
        lower_case: bool,
        strip_accents: impl Into<Option<bool>>,
        add_prefix_space: impl Into<Option<bool>>,
    ) -> ZeroShotClassificationConfig
    where
        RC: ResourceProvider + Send + 'static,
        RV: ResourceProvider + Send + 'static,
    {
        // Box the optional merges resource up-front to keep the literal below flat.
        let merges_resource: Option<Box<dyn ResourceProvider + Send>> = merges_resource
            .map(|resource| Box::new(resource) as Box<dyn ResourceProvider + Send>);
        ZeroShotClassificationConfig {
            model_type,
            model_resource,
            config_resource: Box::new(config_resource),
            vocab_resource: Box::new(vocab_resource),
            merges_resource,
            lower_case,
            strip_accents: strip_accents.into(),
            add_prefix_space: add_prefix_space.into(),
            device: Device::cuda_if_available(),
            kind: None,
        }
    }
}
#[cfg(feature = "remote")]
impl Default for ZeroShotClassificationConfig {
    /// Default configuration: remote BART model fine-tuned on MNLI, with
    /// case-preserving tokenization, on CUDA when available.
    fn default() -> ZeroShotClassificationConfig {
        // Resolve each remote BART-MNLI resource first, then assemble the config.
        let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
            BartModelResources::BART_MNLI,
        )));
        let config_resource: Box<dyn ResourceProvider + Send> = Box::new(
            RemoteResource::from_pretrained(BartConfigResources::BART_MNLI),
        );
        let vocab_resource: Box<dyn ResourceProvider + Send> = Box::new(
            RemoteResource::from_pretrained(BartVocabResources::BART_MNLI),
        );
        let merges_resource: Option<Box<dyn ResourceProvider + Send>> = Some(Box::new(
            RemoteResource::from_pretrained(BartMergesResources::BART_MNLI),
        ));
        ZeroShotClassificationConfig {
            model_type: ModelType::Bart,
            model_resource,
            config_resource,
            vocab_resource,
            merges_resource,
            lower_case: false,
            strip_accents: None,
            add_prefix_space: None,
            device: Device::cuda_if_available(),
            kind: None,
        }
    }
}
/// # Abstraction over the supported zero-shot classification model architectures
#[allow(clippy::large_enum_variant)]
pub enum ZeroShotClassificationOption {
    /// BART sequence classification model
    Bart(BartForSequenceClassification),
    /// DeBERTa sequence classification model
    Deberta(DebertaForSequenceClassification),
    /// DeBERTa V2 sequence classification model
    DebertaV2(DebertaV2ForSequenceClassification),
    /// BERT sequence classification model
    Bert(BertForSequenceClassification),
    /// DistilBERT classifier model
    DistilBert(DistilBertModelClassifier),
    /// MobileBERT sequence classification model
    MobileBert(MobileBertForSequenceClassification),
    /// RoBERTa sequence classification model
    Roberta(RobertaForSequenceClassification),
    /// XLM-RoBERTa sequence classification model (shares the RoBERTa implementation)
    XLMRoberta(RobertaForSequenceClassification),
    /// ALBERT sequence classification model
    Albert(AlbertForSequenceClassification),
    /// XLNet sequence classification model
    XLNet(XLNetForSequenceClassification),
    /// Longformer sequence classification model
    Longformer(LongformerForSequenceClassification),
    /// ONNX encoder model (available with the `onnx` feature)
    #[cfg(feature = "onnx")]
    ONNX(ONNXEncoder),
}
impl ZeroShotClassificationOption {
    /// Instantiates a new zero-shot classification model for the architecture
    /// and resources described by `config`.
    ///
    /// Dispatches to the Torch or ONNX loader based on the variant held in
    /// `config.model_resource`.
    ///
    /// # Errors
    /// Returns a `RustBertError` if the resources cannot be resolved, the
    /// configuration does not match the model type, or model creation fails.
    pub fn new(config: &ZeroShotClassificationConfig) -> Result<Self, RustBertError> {
        match config.model_resource {
            ModelResource::Torch(_) => Self::new_torch(config),
            #[cfg(feature = "onnx")]
            ModelResource::ONNX(_) => Self::new_onnx(config),
        }
    }
    /// Builds a Torch-backed model: parses the configuration file, constructs
    /// the architecture-specific classifier, loads the weights into the
    /// variable store and optionally casts them to `config.kind`.
    fn new_torch(config: &ZeroShotClassificationConfig) -> Result<Self, RustBertError> {
        let device = config.device;
        let weights_path = config.model_resource.get_torch_local_path()?;
        let mut var_store = VarStore::new(device);
        let model_config =
            &ConfigOption::from_file(config.model_type, config.config_resource.get_local_path()?);
        let model_type = config.model_type;
        // Each arm validates that the parsed configuration variant matches the
        // requested model type before constructing the classifier.
        let model = match model_type {
            ModelType::Bart => {
                if let ConfigOption::Bart(config) = model_config {
                    Ok(Self::Bart(
                        BartForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a BartConfig for Bart!".to_string(),
                    ))
                }
            }
            ModelType::Deberta => {
                if let ConfigOption::Deberta(config) = model_config {
                    Ok(Self::Deberta(
                        DebertaForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a DebertaConfig for DeBERTa!".to_string(),
                    ))
                }
            }
            ModelType::DebertaV2 => {
                if let ConfigOption::DebertaV2(config) = model_config {
                    Ok(Self::DebertaV2(
                        DebertaV2ForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    // Fixed error message: this arm requires a DebertaV2Config,
                    // not a DebertaConfig.
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a DebertaV2Config for DeBERTaV2!".to_string(),
                    ))
                }
            }
            ModelType::Bert => {
                if let ConfigOption::Bert(config) = model_config {
                    Ok(Self::Bert(
                        BertForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a BertConfig for Bert!".to_string(),
                    ))
                }
            }
            ModelType::DistilBert => {
                if let ConfigOption::DistilBert(config) = model_config {
                    Ok(Self::DistilBert(
                        DistilBertModelClassifier::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a DistilBertConfig for DistilBert!".to_string(),
                    ))
                }
            }
            ModelType::MobileBert => {
                if let ConfigOption::MobileBert(config) = model_config {
                    Ok(Self::MobileBert(
                        MobileBertForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a MobileBertConfig for MobileBert!".to_string(),
                    ))
                }
            }
            ModelType::Roberta => {
                if let ConfigOption::Roberta(config) = model_config {
                    Ok(Self::Roberta(
                        RobertaForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a RobertaConfig for Roberta!".to_string(),
                    ))
                }
            }
            ModelType::XLMRoberta => {
                // XLM-RoBERTa reuses the RoBERTa implementation with a BERT-style config.
                if let ConfigOption::Bert(config) = model_config {
                    Ok(Self::XLMRoberta(
                        RobertaForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    // Fixed error message: previously referred to "Roberta".
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a BertConfig for XLMRoberta!".to_string(),
                    ))
                }
            }
            ModelType::Albert => {
                if let ConfigOption::Albert(config) = model_config {
                    Ok(Self::Albert(
                        AlbertForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply an AlbertConfig for Albert!".to_string(),
                    ))
                }
            }
            ModelType::XLNet => {
                if let ConfigOption::XLNet(config) = model_config {
                    Ok(Self::XLNet(
                        XLNetForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    // Fixed error message: previously referred to "AlbertConfig for Albert".
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply an XLNetConfig for XLNet!".to_string(),
                    ))
                }
            }
            ModelType::Longformer => {
                if let ConfigOption::Longformer(config) = model_config {
                    Ok(Self::Longformer(
                        LongformerForSequenceClassification::new(var_store.root(), config)?,
                    ))
                } else {
                    Err(RustBertError::InvalidConfigurationError(
                        "You can only supply a LongformerConfig for Longformer!".to_string(),
                    ))
                }
            }
            #[cfg(feature = "onnx")]
            ModelType::ONNX => Err(RustBertError::InvalidConfigurationError(
                "A `ModelType::ONNX` ModelType was provided in the configuration with `ModelResources::TORCH`, these are incompatible".to_string(),
            )),
            _ => Err(RustBertError::InvalidConfigurationError(format!(
                "Zero shot classification not implemented for {model_type:?}!",
            ))),
        }?;
        var_store.load(weights_path)?;
        cast_var_store(&mut var_store, config.kind, device);
        Ok(model)
    }
    /// Builds an ONNX-backed encoder from the model resource's encoder file.
    ///
    /// # Errors
    /// Returns a `RustBertError` when no encoder path is available or when the
    /// ONNX environment/session cannot be created.
    #[cfg(feature = "onnx")]
    pub fn new_onnx(config: &ZeroShotClassificationConfig) -> Result<Self, RustBertError> {
        let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
        let environment = onnx_config.get_environment()?;
        let encoder_file = config
            .model_resource
            .get_onnx_local_paths()?
            .encoder_path
            .ok_or(RustBertError::InvalidConfigurationError(
                "An encoder file must be provided for zero-shot classification ONNX models."
                    .to_string(),
            ))?;
        Ok(Self::ONNX(ONNXEncoder::new(
            encoder_file,
            &environment,
            &onnx_config,
        )?))
    }
    /// Returns the `ModelType` of the wrapped model.
    ///
    /// Note: `XLMRoberta` reports `ModelType::Roberta` since it shares the
    /// RoBERTa implementation.
    pub fn model_type(&self) -> ModelType {
        match *self {
            Self::Bart(_) => ModelType::Bart,
            Self::Deberta(_) => ModelType::Deberta,
            Self::DebertaV2(_) => ModelType::DebertaV2,
            Self::Bert(_) => ModelType::Bert,
            Self::Roberta(_) => ModelType::Roberta,
            Self::XLMRoberta(_) => ModelType::Roberta,
            Self::DistilBert(_) => ModelType::DistilBert,
            Self::MobileBert(_) => ModelType::MobileBert,
            Self::Albert(_) => ModelType::Albert,
            Self::XLNet(_) => ModelType::XLNet,
            Self::Longformer(_) => ModelType::Longformer,
            #[cfg(feature = "onnx")]
            Self::ONNX(_) => ModelType::ONNX,
        }
    }
    /// Runs the forward pass and returns the classification logits
    /// (for BART, the decoder output is used as logits).
    ///
    /// Arguments not supported by a given architecture are ignored.
    ///
    /// # Panics
    /// Panics if `input_ids` is `None` for BART models, or if an underlying
    /// forward pass returns an error.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Tensor {
        match *self {
            Self::Bart(ref model) => {
                model
                    .forward_t(
                        input_ids.expect("`input_ids` must be provided for BART models"),
                        mask,
                        None,
                        None,
                        None,
                        train,
                    )
                    .decoder_output
            }
            Self::Bert(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .logits
            }
            Self::Deberta(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .expect("Error in DeBERTa forward_t")
                    .logits
            }
            Self::DebertaV2(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .expect("Error in DeBERTaV2 forward_t")
                    .logits
            }
            Self::DistilBert(ref model) => {
                model
                    .forward_t(input_ids, mask, input_embeds, train)
                    .expect("Error in distilbert forward_t")
                    .logits
            }
            Self::MobileBert(ref model) => {
                model
                    .forward_t(input_ids, None, None, input_embeds, mask, train)
                    .expect("Error in mobilebert forward_t")
                    .logits
            }
            Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .logits
            }
            Self::Albert(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .logits
            }
            Self::XLNet(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        None,
                        None,
                        None,
                        token_type_ids,
                        input_embeds,
                        train,
                    )
                    .logits
            }
            Self::Longformer(ref model) => {
                model
                    .forward_t(
                        input_ids,
                        mask,
                        None,
                        token_type_ids,
                        position_ids,
                        input_embeds,
                        train,
                    )
                    .expect("Error in Longformer forward pass.")
                    .logits
            }
            #[cfg(feature = "onnx")]
            Self::ONNX(ref model) => model
                .forward(
                    input_ids,
                    // ONNX encoders expect an int64 attention mask.
                    mask.map(|tensor| tensor.to_kind(Kind::Int64)).as_ref(),
                    token_type_ids,
                    position_ids,
                    input_embeds,
                )
                .expect("Error in ONNX forward pass.")
                .logits
                .unwrap(),
        }
    }
}
pub type ZeroShotTemplate = Box<dyn Fn(&str) -> String>;
/// # Zero-shot classification model
///
/// Pairs a tokenizer with a sequence classification model to score arbitrary
/// candidate labels against input sentences.
pub struct ZeroShotClassificationModel {
    // Tokenizer matching the model architecture
    tokenizer: TokenizerOption,
    // Underlying sequence classification model
    zero_shot_classifier: ZeroShotClassificationOption,
    // Device the input tensors are moved to before the forward pass
    device: Device,
}
impl ZeroShotClassificationModel {
pub fn new(
config: ZeroShotClassificationConfig,
) -> Result<ZeroShotClassificationModel, RustBertError> {
let vocab_path = config.vocab_resource.get_local_path()?;
let merges_path = config
.merges_resource
.as_ref()
.map(|resource| resource.get_local_path())
.transpose()?;
let tokenizer = TokenizerOption::from_file(
config.model_type,
vocab_path.to_str().unwrap(),
merges_path.as_deref().map(|path| path.to_str().unwrap()),
config.lower_case,
config.strip_accents,
config.add_prefix_space,
)?;
Self::new_with_tokenizer(config, tokenizer)
}
pub fn new_with_tokenizer(
config: ZeroShotClassificationConfig,
tokenizer: TokenizerOption,
) -> Result<ZeroShotClassificationModel, RustBertError> {
let device = config.device;
let zero_shot_classifier = ZeroShotClassificationOption::new(&config)?;
Ok(ZeroShotClassificationModel {
tokenizer,
zero_shot_classifier,
device,
})
}
pub fn get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
fn prepare_for_model<'a, S, T>(
&self,
inputs: S,
labels: T,
template: Option<ZeroShotTemplate>,
max_len: usize,
) -> Result<(Tensor, Tensor, Tensor), RustBertError>
where
S: AsRef<[&'a str]>,
T: AsRef<[&'a str]>,
{
let label_sentences: Vec<String> = match template {
Some(function) => labels
.as_ref()
.iter()
.map(|label| function(label))
.collect(),
None => labels
.as_ref()
.iter()
.map(|label| format!("This example is about {label}."))
.collect(),
};
let text_pair_list = inputs
.as_ref()
.iter()
.flat_map(|input| {
label_sentences
.iter()
.map(move |label_sentence| (*input, label_sentence.as_str()))
})
.collect::<Vec<(&str, &str)>>();
let mut tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_pair_list(
text_pair_list.as_ref(),
max_len,
&TruncationStrategy::LongestFirst,
0,
);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.ok_or_else(|| RustBertError::ValueError("Got empty iterator as input".to_string()))?;
let pad_id = self
.tokenizer
.get_pad_id()
.expect("The Tokenizer used for sequence classification should contain a PAD id");
let input_ids = tokenized_input
.iter_mut()
.map(|input| {
input.token_ids.resize(max_len, pad_id);
Tensor::from_slice(&(input.token_ids))
})
.collect::<Vec<_>>();
let token_type_ids = tokenized_input
.iter_mut()
.map(|input| {
input
.segment_ids
.resize(max_len, *input.segment_ids.last().unwrap_or(&0));
Tensor::from_slice(&(input.segment_ids))
})
.collect::<Vec<_>>();
let input_ids = Tensor::stack(input_ids.as_slice(), 0).to(self.device);
let token_type_ids = Tensor::stack(token_type_ids.as_slice(), 0)
.to(self.device)
.to_kind(Kind::Int64);
let mask = input_ids
.ne(self
.tokenizer
.get_pad_id()
.expect("The Tokenizer used for zero shot classification should contain a PAD id"))
.to_kind(Bool);
Ok((input_ids, mask, token_type_ids))
}
pub fn predict<'a, S, T>(
&self,
inputs: S,
labels: T,
template: Option<ZeroShotTemplate>,
max_length: usize,
) -> Result<Vec<Label>, RustBertError>
where
S: AsRef<[&'a str]>,
T: AsRef<[&'a str]>,
{
let num_inputs = inputs.as_ref().len();
let (input_tensor, mask, token_type_ids) =
self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length)?;
let output = no_grad(|| {
let output = self.zero_shot_classifier.forward_t(
Some(&input_tensor),
Some(&mask),
Some(&token_type_ids),
None,
None,
false,
);
output.view((num_inputs as i64, labels.as_ref().len() as i64, -1i64))
});
let scores = output.softmax(1, Float).select(-1, -1);
let label_indices = scores.as_ref().argmax(-1, true).squeeze_dim(1);
let scores = scores
.gather(1, &label_indices.unsqueeze(-1), false)
.squeeze_dim(1);
let label_indices = label_indices.iter::<i64>()?.collect::<Vec<i64>>();
let scores = scores.iter::<f64>()?.collect::<Vec<f64>>();
let mut output_labels: Vec<Label> = vec![];
for sentence_idx in 0..label_indices.len() {
let label_string = labels.as_ref()[label_indices[sentence_idx] as usize].to_string();
let label = Label {
text: label_string,
score: scores[sentence_idx],
id: label_indices[sentence_idx],
sentence: sentence_idx,
};
output_labels.push(label)
}
Ok(output_labels)
}
pub fn predict_multilabel<'a, S, T>(
&self,
inputs: S,
labels: T,
template: Option<ZeroShotTemplate>,
max_length: usize,
) -> Result<Vec<Vec<Label>>, RustBertError>
where
S: AsRef<[&'a str]>,
T: AsRef<[&'a str]>,
{
let num_inputs = inputs.as_ref().len();
let (input_tensor, mask, token_type_ids) =
self.prepare_for_model(inputs.as_ref(), labels.as_ref(), template, max_length)?;
let output = no_grad(|| {
let output = self.zero_shot_classifier.forward_t(
Some(&input_tensor),
Some(&mask),
Some(&token_type_ids),
None,
None,
false,
);
output.view((num_inputs as i64, labels.as_ref().len() as i64, -1i64))
});
let scores = output.slice(-1, 0, 3, 2).softmax(-1, Float).select(-1, -1);
let mut output_labels = vec![];
for sentence_idx in 0..num_inputs {
let mut sentence_labels = vec![];
for (label_index, score) in scores
.select(0, sentence_idx as i64)
.iter::<f64>()?
.enumerate()
{
let label_string = labels.as_ref()[label_index].to_string();
let label = Label {
text: label_string,
score,
id: label_index as i64,
sentence: sentence_idx,
};
sentence_labels.push(label);
}
output_labels.push(sentence_labels);
}
Ok(output_labels)
}
}
#[cfg(test)]
mod test {
    use super::*;
    #[test]
    #[ignore] fn test() {
        // Compile-time check that the model (and its boxed resources) is `Send`;
        // ignored at runtime since it would download remote resources.
        let config = ZeroShotClassificationConfig::default();
        let _: Box<dyn Send> = Box::new(ZeroShotClassificationModel::new(config));
    }
}