use tch::{Device, Kind};
use crate::bart::BartGenerator;
use crate::common::error::RustBertError;
use crate::pegasus::PegasusConditionalGenerator;
use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use crate::prophetnet::ProphetNetConditionalGenerator;
use crate::resources::ResourceProvider;
use crate::t5::T5Generator;
use crate::longt5::LongT5Generator;
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::ONNXConditionalGenerator;
#[cfg(feature = "remote")]
use crate::{
bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
resources::RemoteResource,
};
/// Configuration for building a summarization pipeline.
///
/// Bundles the model/tokenizer resources with the text-generation settings. Every field
/// is moved one-to-one into the field of the same name of a `GenerateConfig` by the
/// `From<SummarizationConfig>` conversion in this file, so the precise meaning of each
/// generation setting is defined by `GenerateConfig` and the generators, not here.
///
/// "Default" values quoted below are the ones assigned by `SummarizationConfig::new`.
pub struct SummarizationConfig {
/// Model architecture; `SummarizationOption::new` dispatches on it to pick the generator.
pub model_type: ModelType,
/// Model weights resource. With the `onnx` feature enabled, an `ONNX` resource routes to
/// the ONNX generator regardless of `model_type`.
pub model_resource: ModelResource,
/// Resource provider for the model configuration.
pub config_resource: Box<dyn ResourceProvider + Send>,
/// Resource provider for the tokenizer vocabulary.
pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Optional resource provider for tokenizer merges; `None` when the caller supplies none.
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Minimum length setting for generation (default: 56).
pub min_length: i64,
/// Optional maximum length setting for generation (default: `Some(142)`).
pub max_length: Option<i64>,
/// Sampling flag (default: `false`).
pub do_sample: bool,
/// Early-stopping flag (default: `true`).
pub early_stopping: bool,
/// Number of beams (default: 3).
pub num_beams: i64,
/// Temperature setting (default: 1.0).
pub temperature: f64,
/// Top-k setting (default: 50).
pub top_k: i64,
/// Top-p setting (default: 1.0).
pub top_p: f64,
/// Repetition penalty (default: 1.0).
pub repetition_penalty: f64,
/// Length penalty (default: 1.0).
pub length_penalty: f64,
/// No-repeat n-gram size (default: 3).
pub no_repeat_ngram_size: i64,
/// Number of sequences returned (default: 1).
pub num_return_sequences: i64,
/// Optional number of beam groups (default: `None`).
pub num_beam_groups: Option<i64>,
/// Optional diversity penalty (default: `None`).
pub diversity_penalty: Option<f64>,
/// Target `tch` device (default: `Device::cuda_if_available()`).
pub device: Device,
/// Optional `tch::Kind` (default: `None`).
// NOTE(review): presumably a precision/dtype override for the model weights — confirm
// against `GenerateConfig`.
pub kind: Option<Kind>,
}
impl SummarizationConfig {
    /// Builds a `SummarizationConfig` from explicit resources, using fixed defaults for
    /// every generation setting.
    ///
    /// # Arguments
    ///
    /// * `model_type` - architecture of the model to load.
    /// * `model_resource` - model weights resource.
    /// * `config_resource` - provider for the model configuration.
    /// * `vocab_resource` - provider for the tokenizer vocabulary.
    /// * `merges_resource` - optional provider for tokenizer merges; it must have the same
    ///   concrete type as `vocab_resource`.
    ///
    /// # Returns
    ///
    /// A configuration with the resources boxed as trait objects, `min_length` 56,
    /// `max_length` `Some(142)`, 3 beams, early stopping on, sampling off,
    /// `no_repeat_ngram_size` 3, one returned sequence, neutral (1.0) temperature, top-p and
    /// penalties, `top_k` 50, no beam groups or diversity penalty, no `kind`, and
    /// `Device::cuda_if_available()` as device.
    pub fn new<RC, RV>(
        model_type: ModelType,
        model_resource: ModelResource,
        config_resource: RC,
        vocab_resource: RV,
        merges_resource: Option<RV>,
    ) -> SummarizationConfig
    where
        RC: ResourceProvider + Send + 'static,
        RV: ResourceProvider + Send + 'static,
    {
        // Erase the concrete provider types up front so the struct literal below is a plain
        // list of values.
        let config_resource: Box<dyn ResourceProvider + Send> = Box::new(config_resource);
        let vocab_resource: Box<dyn ResourceProvider + Send> = Box::new(vocab_resource);
        let merges_resource = merges_resource
            .map(|provider| -> Box<dyn ResourceProvider + Send> { Box::new(provider) });

        Self {
            // Resources.
            model_type,
            model_resource,
            config_resource,
            vocab_resource,
            merges_resource,
            // Length limits.
            min_length: 56,
            max_length: Some(142),
            // Search / sampling settings.
            do_sample: false,
            early_stopping: true,
            num_beams: 3,
            temperature: 1.0,
            top_k: 50,
            top_p: 1.0,
            repetition_penalty: 1.0,
            length_penalty: 1.0,
            no_repeat_ngram_size: 3,
            num_return_sequences: 1,
            num_beam_groups: None,
            diversity_penalty: None,
            // Execution settings.
            device: Device::cuda_if_available(),
            kind: None,
        }
    }
}
#[cfg(feature = "remote")]
impl Default for SummarizationConfig {
    /// Default summarization setup: a Bart model whose weights, configuration, vocabulary
    /// and merges all come from the `BART_CNN` remote resources, combined with the
    /// generation defaults of `SummarizationConfig::new`.
    ///
    /// Only available when the `remote` feature is enabled.
    fn default() -> SummarizationConfig {
        let weights = RemoteResource::from_pretrained(BartModelResources::BART_CNN);
        let configuration = RemoteResource::from_pretrained(BartConfigResources::BART_CNN);
        let vocabulary = RemoteResource::from_pretrained(BartVocabResources::BART_CNN);
        let merges = RemoteResource::from_pretrained(BartMergesResources::BART_CNN);

        Self::new(
            ModelType::Bart,
            ModelResource::Torch(Box::new(weights)),
            configuration,
            vocabulary,
            Some(merges),
        )
    }
}
impl From<SummarizationConfig> for GenerateConfig {
fn from(config: SummarizationConfig) -> GenerateConfig {
GenerateConfig {
model_type: config.model_type,
model_resource: config.model_resource,
config_resource: config.config_resource,
merges_resource: config.merges_resource,
vocab_resource: config.vocab_resource,
min_length: config.min_length,
max_length: config.max_length,
do_sample: config.do_sample,
early_stopping: config.early_stopping,
num_beams: config.num_beams,
temperature: config.temperature,
top_k: config.top_k,
top_p: config.top_p,
repetition_penalty: config.repetition_penalty,
length_penalty: config.length_penalty,
no_repeat_ngram_size: config.no_repeat_ngram_size,
num_return_sequences: config.num_return_sequences,
num_beam_groups: config.num_beam_groups,
diversity_penalty: config.diversity_penalty,
device: config.device,
kind: config.kind,
}
}
}
/// Wrapper over the concrete generators that can back a summarization pipeline.
///
/// Each variant owns one generator; the methods on this enum dispatch to whichever
/// generator is held.
pub enum SummarizationOption {
/// Summarizer backed by a `BartGenerator`.
Bart(BartGenerator),
/// Summarizer backed by a `T5Generator`.
T5(T5Generator),
/// Summarizer backed by a `LongT5Generator`.
LongT5(LongT5Generator),
/// Summarizer backed by a `ProphetNetConditionalGenerator`.
ProphetNet(ProphetNetConditionalGenerator),
/// Summarizer backed by a `PegasusConditionalGenerator`.
Pegasus(PegasusConditionalGenerator),
/// Summarizer backed by an `ONNXConditionalGenerator`; only exists with the `onnx` feature.
#[cfg(feature = "onnx")]
ONNX(ONNXConditionalGenerator),
}
impl SummarizationOption {
pub fn new(config: SummarizationConfig) -> Result<Self, RustBertError> {
match (config.model_type, &config.model_resource) {
#[cfg(feature = "onnx")]
(_, &ModelResource::ONNX(_)) => Ok(SummarizationOption::ONNX(
ONNXConditionalGenerator::new(config.into(), None, None)?,
)),
(ModelType::Bart, _) => Ok(SummarizationOption::Bart(BartGenerator::new(
config.into(),
)?)),
(ModelType::T5, _) => Ok(SummarizationOption::T5(T5Generator::new(config.into())?)),
(ModelType::LongT5, _) => Ok(SummarizationOption::LongT5(LongT5Generator::new(
config.into(),
)?)),
(ModelType::ProphetNet, _) => Ok(SummarizationOption::ProphetNet(
ProphetNetConditionalGenerator::new(config.into())?,
)),
(ModelType::Pegasus, _) => Ok(SummarizationOption::Pegasus(
PegasusConditionalGenerator::new(config.into())?,
)),
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Summarization not implemented for {:?}!",
config.model_type
))),
}
}
pub fn new_with_tokenizer(
config: SummarizationConfig,
tokenizer: TokenizerOption,
) -> Result<Self, RustBertError> {
match (config.model_type, &config.model_resource) {
#[cfg(feature = "onnx")]
(_, &ModelResource::ONNX(_)) => Ok(SummarizationOption::ONNX(
ONNXConditionalGenerator::new_with_tokenizer(config.into(), tokenizer, None, None)?,
)),
(ModelType::Bart, _) => Ok(SummarizationOption::Bart(
BartGenerator::new_with_tokenizer(config.into(), tokenizer)?,
)),
(ModelType::T5, _) => Ok(SummarizationOption::T5(T5Generator::new_with_tokenizer(
config.into(),
tokenizer,
)?)),
(ModelType::LongT5, _) => Ok(SummarizationOption::LongT5(
LongT5Generator::new_with_tokenizer(config.into(), tokenizer)?,
)),
(ModelType::ProphetNet, _) => Ok(SummarizationOption::ProphetNet(
ProphetNetConditionalGenerator::new_with_tokenizer(config.into(), tokenizer)?,
)),
(ModelType::Pegasus, _) => Ok(SummarizationOption::Pegasus(
PegasusConditionalGenerator::new_with_tokenizer(config.into(), tokenizer)?,
)),
_ => Err(RustBertError::InvalidConfigurationError(format!(
"Summarization not implemented for {:?}!",
config.model_type
))),
}
}
pub fn model_type(&self) -> ModelType {
match *self {
Self::Bart(_) => ModelType::Bart,
Self::T5(_) => ModelType::T5,
Self::LongT5(_) => ModelType::LongT5,
Self::ProphetNet(_) => ModelType::ProphetNet,
Self::Pegasus(_) => ModelType::Pegasus,
#[cfg(feature = "onnx")]
Self::ONNX(_) => ModelType::ONNX,
}
}
pub fn get_tokenizer(&self) -> &TokenizerOption {
match self {
Self::Bart(model_ref) => model_ref.get_tokenizer(),
Self::T5(model_ref) => model_ref.get_tokenizer(),
Self::LongT5(model_ref) => model_ref.get_tokenizer(),
Self::ProphetNet(model_ref) => model_ref.get_tokenizer(),
Self::Pegasus(model_ref) => model_ref.get_tokenizer(),
#[cfg(feature = "onnx")]
Self::ONNX(model_ref) => model_ref.get_tokenizer(),
}
}
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
match self {
Self::Bart(model_ref) => model_ref.get_tokenizer_mut(),
Self::T5(model_ref) => model_ref.get_tokenizer_mut(),
Self::LongT5(model_ref) => model_ref.get_tokenizer_mut(),
Self::ProphetNet(model_ref) => model_ref.get_tokenizer_mut(),
Self::Pegasus(model_ref) => model_ref.get_tokenizer_mut(),
#[cfg(feature = "onnx")]
Self::ONNX(model_ref) => model_ref.get_tokenizer_mut(),
}
}
pub fn generate<S>(&self, prompt_texts: Option<&[S]>) -> Result<Vec<String>, RustBertError>
where
S: AsRef<str> + Send + Sync,
{
Ok(match *self {
Self::Bart(ref model) => model
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
Self::T5(ref model) => model
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
Self::LongT5(ref model) => model
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
Self::ProphetNet(ref model) => model
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
Self::Pegasus(ref model) => model
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
#[cfg(feature = "onnx")]
Self::ONNX(ref model) => model
.generate(prompt_texts, None)?
.into_iter()
.map(|output| output.text)
.collect(),
})
}
}
/// Summarization pipeline: a generator plus an optional prefix for its inputs.
pub struct SummarizationModel {
// Generator that produces the summaries.
model: SummarizationOption,
// Text prepended to every input in `summarize`; the constructors set it to
// `Some("summarize: ")` for `ModelType::T5` and to `None` for every other model type.
prefix: Option<String>,
}
impl SummarizationModel {
pub fn new(
summarization_config: SummarizationConfig,
) -> Result<SummarizationModel, RustBertError> {
let prefix = match summarization_config.model_type {
ModelType::T5 => Some("summarize: ".to_string()),
_ => None,
};
let model = SummarizationOption::new(summarization_config)?;
Ok(SummarizationModel { model, prefix })
}
pub fn new_with_tokenizer(
summarization_config: SummarizationConfig,
tokenizer: TokenizerOption,
) -> Result<SummarizationModel, RustBertError> {
let prefix = match summarization_config.model_type {
ModelType::T5 => Some("summarize: ".to_string()),
_ => None,
};
let model = SummarizationOption::new_with_tokenizer(summarization_config, tokenizer)?;
Ok(SummarizationModel { model, prefix })
}
pub fn get_tokenizer(&self) -> &TokenizerOption {
self.model.get_tokenizer()
}
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
self.model.get_tokenizer_mut()
}
pub fn summarize<S>(&self, texts: &[S]) -> Result<Vec<String>, RustBertError>
where
S: AsRef<str> + Send + Sync,
{
match &self.prefix {
None => self.model.generate(Some(texts)),
Some(prefix) => {
let texts = texts
.iter()
.map(|text| format!("{}{}", prefix, text.as_ref()))
.collect::<Vec<String>>();
self.model.generate(Some(&texts))
}
}
}
}
// Gated on `remote` as well as `test`: `SummarizationConfig::default()` is only implemented
// when the `remote` feature is enabled, so without this gate the test build fails to compile
// whenever that feature is turned off.
#[cfg(all(test, feature = "remote"))]
mod test {
    use super::*;

    /// Compile-time check that the result of `SummarizationModel::new` (the model or the
    /// error) is `Send`: the coercion to `Box<dyn Send>` only type-checks if it is.
    #[test]
    #[ignore] // Compiling is the whole check; running it would actually construct the model.
    fn test() {
        let config = SummarizationConfig::default();
        let _: Box<dyn Send> = Box::new(SummarizationModel::new(config));
    }
}