use tch::Device;
use crate::bart::{
BartConfigResources, BartGenerator, BartMergesResources, BartModelResources, BartVocabResources,
};
use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::pegasus::PegasusConditionalGenerator;
use crate::pipelines::common::ModelType;
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use crate::prophetnet::ProphetNetConditionalGenerator;
use crate::t5::T5Generator;
/// Configuration for a summarization model.
///
/// Holds the model type, the four resources needed to build the underlying
/// generator, and the generation settings. Apart from `model_type`, every
/// field is copied one-to-one into a `GenerateConfig` by the `From` conversion
/// defined in this file. The defaults quoted below are the values set by the
/// `Default` implementation in this file.
pub struct SummarizationConfig {
/// Model type; selects which generator `SummarizationOption::new` builds (default: `ModelType::Bart`)
pub model_type: ModelType,
/// Model resource (default: remote resource built from `BartModelResources::BART_CNN`)
pub model_resource: Resource,
/// Config resource (default: remote resource built from `BartConfigResources::BART_CNN`)
pub config_resource: Resource,
/// Vocab resource (default: remote resource built from `BartVocabResources::BART_CNN`)
pub vocab_resource: Resource,
/// Merges resource (default: remote resource built from `BartMergesResources::BART_CNN`)
pub merges_resource: Resource,
/// Forwarded to `GenerateConfig::min_length` - presumably the minimum generated length (default: 56)
pub min_length: i64,
/// Forwarded to `GenerateConfig::max_length` - presumably the maximum generated length (default: 142)
pub max_length: i64,
/// Forwarded to `GenerateConfig::do_sample` (default: false)
pub do_sample: bool,
/// Forwarded to `GenerateConfig::early_stopping` (default: true)
pub early_stopping: bool,
/// Forwarded to `GenerateConfig::num_beams` (default: 3)
pub num_beams: i64,
/// Forwarded to `GenerateConfig::temperature` (default: 1.0)
pub temperature: f64,
/// Forwarded to `GenerateConfig::top_k` (default: 50)
pub top_k: i64,
/// Forwarded to `GenerateConfig::top_p` (default: 1.0)
pub top_p: f64,
/// Forwarded to `GenerateConfig::repetition_penalty` (default: 1.0)
pub repetition_penalty: f64,
/// Forwarded to `GenerateConfig::length_penalty` (default: 1.0)
pub length_penalty: f64,
/// Forwarded to `GenerateConfig::no_repeat_ngram_size` (default: 3)
pub no_repeat_ngram_size: i64,
/// Forwarded to `GenerateConfig::num_return_sequences` (default: 1)
pub num_return_sequences: i64,
/// Forwarded to `GenerateConfig::num_beam_groups` (default: None)
pub num_beam_groups: Option<i64>,
/// Forwarded to `GenerateConfig::diversity_penalty` (default: None)
pub diversity_penalty: Option<f64>,
/// Device handed to the generator (default: `Device::cuda_if_available()`)
pub device: Device,
}
impl SummarizationConfig {
    /// Creates a summarization configuration for a caller-supplied model.
    ///
    /// The model type and the four resources come from the arguments, the
    /// device is set to `Device::cuda_if_available()`, and every remaining
    /// generation setting is taken from `SummarizationConfig::default()`.
    ///
    /// # Arguments
    ///
    /// * `model_type` - `ModelType` of the model the resources belong to
    /// * `model_resource` - `Resource` for the model
    /// * `config_resource` - `Resource` for the model configuration
    /// * `vocab_resource` - `Resource` for the vocabulary
    /// * `merges_resource` - `Resource` for the merges
    pub fn new(
        model_type: ModelType,
        model_resource: Resource,
        config_resource: Resource,
        vocab_resource: Resource,
        merges_resource: Resource,
    ) -> SummarizationConfig {
        Self {
            // Same value the default configuration uses; kept explicit here.
            device: Device::cuda_if_available(),
            merges_resource,
            vocab_resource,
            config_resource,
            model_resource,
            model_type,
            // Generation settings (lengths, beams, sampling, penalties) are not
            // configurable through this constructor and fall back to the defaults.
            ..Self::default()
        }
    }
}
impl Default for SummarizationConfig {
    /// Returns the default configuration: a BART model using the `BART_CNN`
    /// remote resources, with the generation settings listed in the body and
    /// the device chosen by `Device::cuda_if_available()`.
    fn default() -> SummarizationConfig {
        // The four pretrained BART CNN resources, each wrapped as a remote resource.
        let model_resource =
            Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART_CNN));
        let config_resource =
            Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART_CNN));
        let vocab_resource =
            Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART_CNN));
        let merges_resource =
            Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART_CNN));

        Self {
            model_type: ModelType::Bart,
            model_resource,
            config_resource,
            vocab_resource,
            merges_resource,
            // Output length bounds.
            min_length: 56,
            max_length: 142,
            // Search settings.
            do_sample: false,
            early_stopping: true,
            num_beams: 3,
            temperature: 1.0,
            top_k: 50,
            top_p: 1.0,
            // Penalties and repetition control.
            repetition_penalty: 1.0,
            length_penalty: 1.0,
            no_repeat_ngram_size: 3,
            num_return_sequences: 1,
            num_beam_groups: None,
            diversity_penalty: None,
            device: Device::cuda_if_available(),
        }
    }
}
impl From<SummarizationConfig> for GenerateConfig {
fn from(config: SummarizationConfig) -> GenerateConfig {
GenerateConfig {
model_resource: config.model_resource,
config_resource: config.config_resource,
merges_resource: config.merges_resource,
vocab_resource: config.vocab_resource,
min_length: config.min_length,
max_length: config.max_length,
do_sample: config.do_sample,
early_stopping: config.early_stopping,
num_beams: config.num_beams,
temperature: config.temperature,
top_k: config.top_k,
top_p: config.top_p,
repetition_penalty: config.repetition_penalty,
length_penalty: config.length_penalty,
no_repeat_ngram_size: config.no_repeat_ngram_size,
num_return_sequences: config.num_return_sequences,
num_beam_groups: config.num_beam_groups,
diversity_penalty: config.diversity_penalty,
device: config.device,
}
}
}
/// Wrapper over the generators that can be used for summarization.
///
/// Each variant owns one concrete generator; the methods on this enum
/// dispatch to whichever generator is held.
pub enum SummarizationOption {
/// Summarizer backed by a `BartGenerator`
Bart(BartGenerator),
/// Summarizer backed by a `T5Generator`
T5(T5Generator),
/// Summarizer backed by a `ProphetNetConditionalGenerator`
ProphetNet(ProphetNetConditionalGenerator),
/// Summarizer backed by a `PegasusConditionalGenerator`
Pegasus(PegasusConditionalGenerator),
}
impl SummarizationOption {
    /// Builds the generator matching `config.model_type`.
    ///
    /// The configuration is converted into a `GenerateConfig` and handed to the
    /// constructor of the selected generator.
    ///
    /// # Errors
    ///
    /// Returns `RustBertError::InvalidConfigurationError` when the model type is
    /// not one of Bart, T5, ProphetNet or Pegasus, and propagates any error
    /// returned by the generator constructor.
    pub fn new(config: SummarizationConfig) -> Result<Self, RustBertError> {
        let option = match config.model_type {
            ModelType::Bart => Self::Bart(BartGenerator::new(config.into())?),
            ModelType::T5 => Self::T5(T5Generator::new(config.into())?),
            ModelType::ProphetNet => {
                Self::ProphetNet(ProphetNetConditionalGenerator::new(config.into())?)
            }
            ModelType::Pegasus => Self::Pegasus(PegasusConditionalGenerator::new(config.into())?),
            _ => {
                return Err(RustBertError::InvalidConfigurationError(format!(
                    "Summarization not implemented for {:?}!",
                    config.model_type
                )));
            }
        };
        Ok(option)
    }

    /// Returns the `ModelType` corresponding to the wrapped generator.
    pub fn model_type(&self) -> ModelType {
        match self {
            Self::Bart(_) => ModelType::Bart,
            Self::T5(_) => ModelType::T5,
            Self::ProphetNet(_) => ModelType::ProphetNet,
            Self::Pegasus(_) => ModelType::Pegasus,
        }
    }

    /// Runs generation with the wrapped generator and returns the generated texts.
    ///
    /// `prompt_texts` is forwarded unchanged as the first argument of the
    /// generator's `generate`, with `None` as the second argument. The `text`
    /// field of each output is collected, in the order the generator returned them.
    pub fn generate<S>(&self, prompt_texts: Option<&[S]>) -> Vec<String>
    where
        S: AsRef<str> + Sync,
    {
        // Only the dispatch differs per variant; the post-processing is shared.
        let outputs = match self {
            Self::Bart(model) => model.generate(prompt_texts, None),
            Self::T5(model) => model.generate(prompt_texts, None),
            Self::ProphetNet(model) => model.generate(prompt_texts, None),
            Self::Pegasus(model) => model.generate(prompt_texts, None),
        };
        outputs.into_iter().map(|output| output.text).collect()
    }
}
/// Summarization model: a generator plus an optional input prefix.
pub struct SummarizationModel {
// Generator wrapper that produces the summaries.
model: SummarizationOption,
// Text prepended to every input before generation; `new` sets it to
// "summarize: " for `ModelType::T5` and to `None` for every other model type.
prefix: Option<String>,
}
impl SummarizationModel {
    /// Builds a new `SummarizationModel` from a configuration.
    ///
    /// T5 models get the input prefix `"summarize: "`; all other model types
    /// get no prefix.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `SummarizationOption::new`.
    pub fn new(
        summarization_config: SummarizationConfig,
    ) -> Result<SummarizationModel, RustBertError> {
        // Decide on the prefix before the configuration is consumed below.
        let prefix = if let ModelType::T5 = summarization_config.model_type {
            Some(String::from("summarize: "))
        } else {
            None
        };
        SummarizationOption::new(summarization_config).map(|model| Self { model, prefix })
    }

    /// Summarizes the given texts.
    ///
    /// When a prefix is set, it is prepended to every text before the batch is
    /// handed to the generator; otherwise the texts are passed through as-is.
    /// Returns the generated strings produced by the underlying generator.
    pub fn summarize<S>(&self, texts: &[S]) -> Vec<String>
    where
        S: AsRef<str> + Sync,
    {
        if let Some(prefix) = self.prefix.as_ref() {
            let prefixed: Vec<String> = texts
                .iter()
                .map(|text| [prefix.as_str(), text.as_ref()].concat())
                .collect();
            self.model.generate(Some(prefixed.as_slice()))
        } else {
            self.model.generate(Some(texts))
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Type-level check that the result of `SummarizationModel::new` is `Send`:
    /// the coercion to `Box<dyn Send>` only compiles if it is.
    ///
    /// NOTE(review): ignored by default, presumably because building the default
    /// model at runtime needs the pretrained resources to be available - confirm.
    #[test]
    #[ignore]
    fn test() {
        let config = SummarizationConfig::default();
        let model = SummarizationModel::new(config);
        let _: Box<dyn Send> = Box::new(model);
    }
}