use tch::{Device, Kind};
use crate::common::error::RustBertError;
use crate::gpt2::GPT2Generator;
use crate::gpt_j::GptJGenerator;
use crate::gpt_neo::GptNeoGenerator;
use crate::openai_gpt::OpenAIGenerator;
use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator};
use crate::reformer::ReformerGenerator;
use crate::resources::ResourceProvider;
use crate::t5::T5Generator;
use crate::xlnet::XLNetGenerator;
#[cfg(feature = "onnx")]
use crate::pipelines::onnx::ONNXCausalGenerator;
#[cfg(feature = "remote")]
use crate::{
gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
resources::RemoteResource,
};
/// # Configuration for text generation
///
/// Bundles the resources required to load a generation model together with the
/// decoding hyper-parameters. Converted into a [`GenerateConfig`] when the model
/// is instantiated. Defaults (set by [`TextGenerationConfig::new`]) are noted per field.
pub struct TextGenerationConfig {
    /// Model type used for generation (must match the provided weights)
    pub model_type: ModelType,
    /// Resource pointing to the model weights
    pub model_resource: ModelResource,
    /// Resource pointing to the model configuration file
    pub config_resource: Box<dyn ResourceProvider + Send>,
    /// Resource pointing to the tokenizer vocabulary
    pub vocab_resource: Box<dyn ResourceProvider + Send>,
    /// Optional resource pointing to the tokenizer merges (required by BPE-style tokenizers only)
    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
    /// Minimum generated sequence length (default: 0)
    pub min_length: i64,
    /// Maximum generated sequence length, `None` for no limit (default: `Some(56)`)
    pub max_length: Option<i64>,
    /// Whether to sample during generation (default: true)
    pub do_sample: bool,
    /// Early-stopping flag for beam search (default: true)
    pub early_stopping: bool,
    /// Number of beams for beam search (default: 5)
    pub num_beams: i64,
    /// Sampling temperature (default: 1.0)
    pub temperature: f64,
    /// Top-k filtering value, 0 disables it (default: 0)
    pub top_k: i64,
    /// Top-p (nucleus) filtering value (default: 0.9)
    pub top_p: f64,
    /// Repetition penalty, 1.0 applies no penalty (default: 1.0)
    pub repetition_penalty: f64,
    /// Length penalty applied during beam search (default: 1.0)
    pub length_penalty: f64,
    /// N-gram size that may not repeat in the output, 0 disables it (default: 0)
    pub no_repeat_ngram_size: i64,
    /// Number of sequences returned per input (default: 1)
    pub num_return_sequences: i64,
    /// Number of beam groups for diverse beam search (default: `None`)
    pub num_beam_groups: Option<i64>,
    /// Diversity penalty between beam groups (default: `None`)
    pub diversity_penalty: Option<f64>,
    /// Device to place the model on (default: CUDA if available)
    pub device: Device,
    /// Optional tensor kind to load the weights as (default: `None`)
    pub kind: Option<Kind>,
}
impl TextGenerationConfig {
    /// Builds a new `TextGenerationConfig` with default generation hyper-parameters.
    ///
    /// # Arguments
    ///
    /// * `model_type` - `ModelType` of the model to load (must match the supplied weights)
    /// * `model_resource` - `ModelResource` pointing to the model weights
    /// * `config_resource` - `ResourceProvider` pointing to the model configuration
    /// * `vocab_resource` - `ResourceProvider` pointing to the tokenizer vocabulary
    /// * `merges_resource` - Optional `ResourceProvider` pointing to the tokenizer merges
    ///   (needed for BPE-style tokenizers only)
    pub fn new<RC, RV>(
        model_type: ModelType,
        model_resource: ModelResource,
        config_resource: RC,
        vocab_resource: RV,
        merges_resource: Option<RV>,
    ) -> TextGenerationConfig
    where
        RC: ResourceProvider + Send + 'static,
        RV: ResourceProvider + Send + 'static,
    {
        // Erase the concrete resource type behind a boxed trait object up-front.
        let merges_resource = merges_resource
            .map(|resource| Box::new(resource) as Box<dyn ResourceProvider + Send>);
        TextGenerationConfig {
            model_type,
            model_resource,
            config_resource: Box::new(config_resource),
            vocab_resource: Box::new(vocab_resource),
            merges_resource,
            min_length: 0,
            max_length: Some(56),
            do_sample: true,
            early_stopping: true,
            num_beams: 5,
            temperature: 1.0,
            top_k: 0,
            top_p: 0.9,
            repetition_penalty: 1.0,
            length_penalty: 1.0,
            no_repeat_ngram_size: 0,
            num_return_sequences: 1,
            num_beam_groups: None,
            diversity_penalty: None,
            device: Device::cuda_if_available(),
            kind: None,
        }
    }
}
#[cfg(feature = "remote")]
impl Default for TextGenerationConfig {
    /// Default configuration: a GPT2-medium model with weights, configuration,
    /// vocabulary and merges fetched from the remote model hub.
    fn default() -> TextGenerationConfig {
        let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
            Gpt2ModelResources::GPT2_MEDIUM,
        )));
        let config_resource = RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2_MEDIUM);
        let vocab_resource = RemoteResource::from_pretrained(Gpt2VocabResources::GPT2_MEDIUM);
        let merges_resource = RemoteResource::from_pretrained(Gpt2MergesResources::GPT2_MEDIUM);
        TextGenerationConfig::new(
            ModelType::GPT2,
            model_resource,
            config_resource,
            vocab_resource,
            Some(merges_resource),
        )
    }
}
impl From<TextGenerationConfig> for GenerateConfig {
    /// Converts a `TextGenerationConfig` into the generic `GenerateConfig` consumed by
    /// the language generators. Straight field-by-field move; no values are altered.
    fn from(config: TextGenerationConfig) -> GenerateConfig {
        GenerateConfig {
            model_type: config.model_type,
            model_resource: config.model_resource,
            config_resource: config.config_resource,
            merges_resource: config.merges_resource,
            vocab_resource: config.vocab_resource,
            min_length: config.min_length,
            max_length: config.max_length,
            do_sample: config.do_sample,
            early_stopping: config.early_stopping,
            num_beams: config.num_beams,
            temperature: config.temperature,
            top_k: config.top_k,
            top_p: config.top_p,
            repetition_penalty: config.repetition_penalty,
            length_penalty: config.length_penalty,
            no_repeat_ngram_size: config.no_repeat_ngram_size,
            num_return_sequences: config.num_return_sequences,
            num_beam_groups: config.num_beam_groups,
            diversity_penalty: config.diversity_penalty,
            device: config.device,
            kind: config.kind,
        }
    }
}
/// # Abstraction over the concrete generator backing a text generation pipeline
///
/// One variant per supported architecture; dispatch to the wrapped generator is
/// performed by the methods on this enum.
pub enum TextGenerationOption {
    /// Generator based on a GPT2 model
    GPT2(GPT2Generator),
    /// Generator based on an OpenAI GPT model
    GPT(OpenAIGenerator),
    /// Generator based on a GPT-Neo model
    GPTNeo(GptNeoGenerator),
    /// Generator based on a GPT-J model
    GPTJ(GptJGenerator),
    /// Generator based on an XLNet model
    XLNet(XLNetGenerator),
    /// Generator based on a Reformer model
    Reformer(ReformerGenerator),
    /// Generator based on a T5 model
    T5(T5Generator),
    /// Generator based on an ONNX causal model (requires the `onnx` feature)
    #[cfg(feature = "onnx")]
    ONNX(ONNXCausalGenerator),
}
impl TextGenerationOption {
    /// Instantiates the generator matching the configuration's model type.
    ///
    /// When the `onnx` feature is enabled and the model resource is an ONNX resource,
    /// an ONNX generator is built regardless of `model_type` — this arm must stay
    /// first so it takes precedence over the per-architecture arms below.
    ///
    /// # Errors
    /// Returns `RustBertError::InvalidConfigurationError` for unsupported model types.
    pub fn new(config: TextGenerationConfig) -> Result<Self, RustBertError> {
        match (config.model_type, &config.model_resource) {
            #[cfg(feature = "onnx")]
            (_, &ModelResource::ONNX(_)) => Ok(TextGenerationOption::ONNX(
                ONNXCausalGenerator::new(config.into(), None, None)?,
            )),
            (ModelType::GPT2, _) => Ok(TextGenerationOption::GPT2(GPT2Generator::new(
                config.into(),
            )?)),
            (ModelType::OpenAiGpt, _) => Ok(TextGenerationOption::GPT(OpenAIGenerator::new(
                config.into(),
            )?)),
            (ModelType::XLNet, _) => Ok(TextGenerationOption::XLNet(XLNetGenerator::new(
                config.into(),
            )?)),
            (ModelType::Reformer, _) => Ok(TextGenerationOption::Reformer(ReformerGenerator::new(
                config.into(),
            )?)),
            (ModelType::GPTNeo, _) => Ok(TextGenerationOption::GPTNeo(GptNeoGenerator::new(
                config.into(),
            )?)),
            (ModelType::GPTJ, _) => Ok(TextGenerationOption::GPTJ(GptJGenerator::new(
                config.into(),
            )?)),
            (ModelType::T5, _) => Ok(TextGenerationOption::T5(T5Generator::new(config.into())?)),
            _ => Err(RustBertError::InvalidConfigurationError(format!(
                "Text generation not implemented for {:?}!",
                config.model_type
            ))),
        }
    }

    /// Same as [`Self::new`] but reuses a caller-provided tokenizer instead of
    /// loading one from the configuration's vocabulary resources.
    ///
    /// # Errors
    /// Returns `RustBertError::InvalidConfigurationError` for unsupported model types.
    pub fn new_with_tokenizer(
        config: TextGenerationConfig,
        tokenizer: TokenizerOption,
    ) -> Result<Self, RustBertError> {
        match (config.model_type, &config.model_resource) {
            #[cfg(feature = "onnx")]
            (_, &ModelResource::ONNX(_)) => Ok(TextGenerationOption::ONNX(
                ONNXCausalGenerator::new_with_tokenizer(config.into(), tokenizer, None, None)?,
            )),
            (ModelType::GPT2, _) => Ok(TextGenerationOption::GPT2(
                GPT2Generator::new_with_tokenizer(config.into(), tokenizer)?,
            )),
            (ModelType::OpenAiGpt, _) => Ok(TextGenerationOption::GPT(
                OpenAIGenerator::new_with_tokenizer(config.into(), tokenizer)?,
            )),
            (ModelType::XLNet, _) => Ok(TextGenerationOption::XLNet(
                XLNetGenerator::new_with_tokenizer(config.into(), tokenizer)?,
            )),
            (ModelType::Reformer, _) => Ok(TextGenerationOption::Reformer(
                ReformerGenerator::new_with_tokenizer(config.into(), tokenizer)?,
            )),
            (ModelType::GPTNeo, _) => Ok(TextGenerationOption::GPTNeo(
                GptNeoGenerator::new_with_tokenizer(config.into(), tokenizer)?,
            )),
            (ModelType::GPTJ, _) => Ok(TextGenerationOption::GPTJ(
                GptJGenerator::new_with_tokenizer(config.into(), tokenizer)?,
            )),
            (ModelType::T5, _) => Ok(TextGenerationOption::T5(T5Generator::new_with_tokenizer(
                config.into(),
                tokenizer,
            )?)),
            _ => Err(RustBertError::InvalidConfigurationError(format!(
                "Text generation not implemented for {:?}!",
                config.model_type
            ))),
        }
    }

    /// Returns the `ModelType` corresponding to the wrapped generator variant.
    pub fn model_type(&self) -> ModelType {
        match *self {
            Self::GPT(_) => ModelType::OpenAiGpt,
            Self::GPT2(_) => ModelType::GPT2,
            Self::GPTNeo(_) => ModelType::GPTNeo,
            Self::GPTJ(_) => ModelType::GPTJ,
            Self::XLNet(_) => ModelType::XLNet,
            Self::Reformer(_) => ModelType::Reformer,
            Self::T5(_) => ModelType::T5,
            #[cfg(feature = "onnx")]
            Self::ONNX(_) => ModelType::ONNX,
        }
    }

    /// Returns a shared reference to the wrapped generator's tokenizer.
    pub fn get_tokenizer(&self) -> &TokenizerOption {
        match self {
            Self::GPT(model_ref) => model_ref.get_tokenizer(),
            Self::GPT2(model_ref) => model_ref.get_tokenizer(),
            Self::GPTNeo(model_ref) => model_ref.get_tokenizer(),
            Self::GPTJ(model_ref) => model_ref.get_tokenizer(),
            Self::XLNet(model_ref) => model_ref.get_tokenizer(),
            Self::Reformer(model_ref) => model_ref.get_tokenizer(),
            Self::T5(model_ref) => model_ref.get_tokenizer(),
            #[cfg(feature = "onnx")]
            Self::ONNX(model_ref) => model_ref.get_tokenizer(),
        }
    }

    /// Returns a mutable reference to the wrapped generator's tokenizer.
    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        match self {
            Self::GPT(model_ref) => model_ref.get_tokenizer_mut(),
            Self::GPT2(model_ref) => model_ref.get_tokenizer_mut(),
            Self::GPTNeo(model_ref) => model_ref.get_tokenizer_mut(),
            Self::GPTJ(model_ref) => model_ref.get_tokenizer_mut(),
            Self::XLNet(model_ref) => model_ref.get_tokenizer_mut(),
            Self::Reformer(model_ref) => model_ref.get_tokenizer_mut(),
            Self::T5(model_ref) => model_ref.get_tokenizer_mut(),
            #[cfg(feature = "onnx")]
            Self::ONNX(model_ref) => model_ref.get_tokenizer_mut(),
        }
    }

    /// Generates token indices for the given prompts, dispatching to the wrapped generator.
    ///
    /// # Arguments
    /// * `prompt_texts` - Optional prompts to condition the generation on.
    /// * `min_length` / `max_length` - Optional per-call length bounds forwarded via
    ///   `GenerateOptions` (remaining options are left at their defaults).
    ///
    /// # Returns
    /// One `Vec<i64>` of token indices per generated sequence; only the indices are
    /// kept from each generator output.
    pub fn generate_indices<S>(
        &self,
        prompt_texts: Option<&[S]>,
        min_length: Option<i64>,
        max_length: Option<i64>,
    ) -> Result<Vec<Vec<i64>>, RustBertError>
    where
        S: AsRef<str> + Send + Sync,
    {
        let generate_options = Some(GenerateOptions {
            min_length,
            max_length,
            ..Default::default()
        });
        Ok(match *self {
            Self::GPT(ref model) => model
                .generate_indices(prompt_texts, generate_options)?
                .into_iter()
                .map(|output| output.indices)
                .collect(),
            Self::GPT2(ref model) => model
                .generate_indices(prompt_texts, generate_options)?
                .into_iter()
                .map(|output| output.indices)
                .collect(),
            Self::GPTNeo(ref model) => model
                .generate_indices(prompt_texts, generate_options)?
                .into_iter()
                .map(|output| output.indices)
                .collect(),
            Self::GPTJ(ref model) => model
                .generate_indices(prompt_texts, generate_options)?
                .into_iter()
                .map(|output| output.indices)
                .collect(),
            Self::XLNet(ref model) => model
                .generate_indices(prompt_texts, generate_options)?
                .into_iter()
                .map(|output| output.indices)
                .collect(),
            Self::Reformer(ref model) => model
                .generate_indices(prompt_texts, generate_options)?
                .into_iter()
                .map(|output| output.indices)
                .collect(),
            Self::T5(ref model) => model
                .generate_indices(prompt_texts, generate_options)?
                .into_iter()
                .map(|output| output.indices)
                .collect(),
            #[cfg(feature = "onnx")]
            Self::ONNX(ref model) => model
                .generate_indices(prompt_texts, generate_options)?
                .into_iter()
                .map(|output| output.indices)
                .collect(),
        })
    }

    /// Casts the model weights to half precision.
    ///
    /// # Errors
    /// Returns an `OrtError` for ONNX models, which do not support type casting.
    pub fn half(&mut self) -> Result<(), RustBertError> {
        match self {
            Self::GPT(model_ref) => model_ref.half(),
            Self::GPT2(model_ref) => model_ref.half(),
            Self::GPTNeo(model_ref) => model_ref.half(),
            Self::GPTJ(model_ref) => model_ref.half(),
            Self::XLNet(model_ref) => model_ref.half(),
            Self::Reformer(model_ref) => model_ref.half(),
            Self::T5(model_ref) => model_ref.half(),
            #[cfg(feature = "onnx")]
            Self::ONNX(_) => Err(RustBertError::OrtError(
                "Type casting not supported for ONNX models.".to_string(),
            )),
        }
    }

    /// Casts the model weights to full (single) precision.
    ///
    /// # Errors
    /// Returns an `OrtError` for ONNX models, which do not support type casting.
    pub fn float(&mut self) -> Result<(), RustBertError> {
        match self {
            Self::GPT(model_ref) => model_ref.float(),
            Self::GPT2(model_ref) => model_ref.float(),
            Self::GPTNeo(model_ref) => model_ref.float(),
            Self::GPTJ(model_ref) => model_ref.float(),
            Self::XLNet(model_ref) => model_ref.float(),
            Self::Reformer(model_ref) => model_ref.float(),
            Self::T5(model_ref) => model_ref.float(),
            #[cfg(feature = "onnx")]
            Self::ONNX(_) => Err(RustBertError::OrtError(
                "Type casting not supported for ONNX models.".to_string(),
            )),
        }
    }

    /// Moves the model to the given device.
    ///
    /// # Errors
    /// Returns an `OrtError` for ONNX models, which do not support device assignment.
    pub fn set_device(&mut self, device: Device) -> Result<(), RustBertError> {
        match self {
            Self::GPT(model_ref) => model_ref.set_device(device),
            Self::GPT2(model_ref) => model_ref.set_device(device),
            Self::GPTNeo(model_ref) => model_ref.set_device(device),
            Self::GPTJ(model_ref) => model_ref.set_device(device),
            Self::XLNet(model_ref) => model_ref.set_device(device),
            Self::Reformer(model_ref) => model_ref.set_device(device),
            Self::T5(model_ref) => model_ref.set_device(device),
            #[cfg(feature = "onnx")]
            Self::ONNX(_) => Err(RustBertError::OrtError(
                "Device assignment not supported for ONNX models.".to_string(),
            )),
        }
    }
}
/// # TextGenerationModel
///
/// Pipeline generating free text from user prompts. Wraps a concrete generator
/// and an optional prompt prefix that is prepended to every input.
pub struct TextGenerationModel {
    // Underlying generator (architecture-specific)
    model: TextGenerationOption,
    // Prefix prepended to every prompt; set for XLNet models, `None` otherwise
    prefix: Option<String>,
    // Token length of `prefix`, used to strip prefix tokens from generated outputs
    prefix_length: Option<i64>,
    // Minimum generation length carried over from the configuration
    min_length: i64,
    // Maximum generation length carried over from the configuration
    max_length: Option<i64>,
}
impl TextGenerationModel {
    /// Builds a new `TextGenerationModel` from a configuration, loading the
    /// tokenizer from the configuration's resources.
    ///
    /// # Errors
    /// Propagates any error raised while instantiating the underlying generator.
    pub fn new(
        generation_config: TextGenerationConfig,
    ) -> Result<TextGenerationModel, RustBertError> {
        // Extract prefix and length bounds before the config is consumed by the generator.
        let (prefix, min_length, max_length) =
            TextGenerationModel::get_prefix_min_max_length(&generation_config);
        let model = TextGenerationOption::new(generation_config)?;
        // Cache the tokenized prefix length so it can be stripped from outputs later.
        let prefix_length = prefix
            .as_ref()
            .map(|prefix| model.get_tokenizer().tokenize(prefix).len() as i64);
        Ok(TextGenerationModel {
            model,
            prefix,
            prefix_length,
            min_length,
            max_length,
        })
    }

    /// Same as [`Self::new`] but reuses a caller-provided tokenizer.
    ///
    /// # Errors
    /// Propagates any error raised while instantiating the underlying generator.
    pub fn new_with_tokenizer(
        generation_config: TextGenerationConfig,
        tokenizer: TokenizerOption,
    ) -> Result<TextGenerationModel, RustBertError> {
        let (prefix, min_length, max_length) =
            TextGenerationModel::get_prefix_min_max_length(&generation_config);
        let model = TextGenerationOption::new_with_tokenizer(generation_config, tokenizer)?;
        let prefix_length = prefix
            .as_ref()
            .map(|prefix| model.get_tokenizer().tokenize(prefix).len() as i64);
        Ok(TextGenerationModel {
            model,
            prefix,
            prefix_length,
            min_length,
            max_length,
        })
    }

    /// Derives the default prompt prefix and length bounds from the configuration.
    ///
    /// XLNet models receive a long fixed "padding text" prefix (a common trick for
    /// short-prompt generation with XLNet); all other model types get no prefix.
    fn get_prefix_min_max_length(
        generation_config: &TextGenerationConfig,
    ) -> (Option<String>, i64, Option<i64>) {
        let prefix = match generation_config.model_type {
            ModelType::XLNet => Some(
                "In 1991, the remains of Russian Tsar Nicholas II and his family \
            (except for Alexei and Maria) are discovered. \
            The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the \
            remainder of the story. 1883 Western Siberia, \
            a young Grigori Rasputin is asked by his father and a group of men to perform magic. \
            Rasputin has a vision and denounces one of the men as a horse thief. Although his \
            father initially slaps him for making such an accusation, Rasputin watches as the \
            man is chased outside and beaten. Twenty years later, Rasputin sees a vision of \
            the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, \
            with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
                    .to_string(),
            ),
            _ => None,
        };
        let min_length = generation_config.min_length;
        let max_length = generation_config.max_length;
        (prefix, min_length, max_length)
    }

    /// Returns a shared reference to the pipeline's tokenizer.
    pub fn get_tokenizer(&self) -> &TokenizerOption {
        self.model.get_tokenizer()
    }

    /// Returns a mutable reference to the pipeline's tokenizer.
    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        self.model.get_tokenizer_mut()
    }

    /// Casts the model weights to half precision (not supported for ONNX models).
    pub fn half(&mut self) -> Result<(), RustBertError> {
        self.model.half()
    }

    /// Casts the model weights to full precision (not supported for ONNX models).
    pub fn float(&mut self) -> Result<(), RustBertError> {
        self.model.float()
    }

    /// Moves the model to the given device (not supported for ONNX models).
    pub fn set_device(&mut self, device: Device) -> Result<(), RustBertError> {
        self.model.set_device(device)
    }

    /// Generates one text per input prompt.
    ///
    /// # Arguments
    /// * `texts` - Input prompts.
    /// * `prefix` - Optional per-call prefix; overrides the pipeline-level prefix
    ///   (the XLNet padding text) when provided.
    ///
    /// # Returns
    /// The decoded generated strings, with any prefix tokens stripped from the output.
    ///
    /// # Errors
    /// Returns a `ValueError` if a prefix is present but its token length is unknown;
    /// otherwise propagates generation errors.
    pub fn generate<'a, S>(
        &self,
        texts: &[S],
        prefix: impl Into<Option<&'a str>>,
    ) -> Result<Vec<String>, RustBertError>
    where
        S: AsRef<str> + Send + Sync,
    {
        // Resolve the effective prefix: a caller-supplied prefix wins over the
        // pipeline default, and its token length is computed on the fly.
        let (prefix, prefix_length) = match (prefix.into(), &self.prefix) {
            (Some(query_prefix), _) => (
                Some(query_prefix),
                Some(self.model.get_tokenizer().tokenize(query_prefix).len() as i64),
            ),
            (None, Some(pipeline_prefix)) => (Some(pipeline_prefix.as_str()), self.prefix_length),
            (None, None) => (None, None),
        };
        let generated_indices = match (prefix, prefix_length) {
            (None, _) => self.model.generate_indices(Some(texts), None, None),
            (Some(prefix), Some(prefix_length)) => {
                // Prepend the prefix to each prompt and widen the length bounds by the
                // prefix token count so the effective generation budget is unchanged.
                // NOTE(review): `.as_ref()` on `&[S]` is a no-op here.
                let texts = texts
                    .as_ref()
                    .iter()
                    .map(|text| format!("{} {}", prefix, text.as_ref()))
                    .collect::<Vec<String>>();
                self.model.generate_indices(
                    Some(&texts),
                    Some(self.min_length + prefix_length),
                    self.max_length.map(|max_length| max_length + prefix_length),
                )
            }
            _ => Err(RustBertError::ValueError(
                "Prefix length not defined but prefix provided!".to_string(),
            )),
        }?;
        let mut output = Vec::with_capacity(generated_indices.len());
        for generated_sequence in generated_indices {
            // Skip the prefix tokens before decoding so they do not appear in the result.
            output.push(self.model.get_tokenizer().decode(
                &generated_sequence[prefix_length.unwrap_or(0) as usize..],
                true,
                true,
            ));
        }
        Ok(output)
    }
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    #[ignore] // compilation is sufficient: this only verifies the pipeline is `Send`
    // `TextGenerationConfig::default()` is only implemented with the `remote` feature
    // (see the `#[cfg(feature = "remote")]` on the `Default` impl), so gate the test
    // accordingly to keep `cargo test --no-default-features` compiling.
    #[cfg(feature = "remote")]
    fn test() {
        let config = TextGenerationConfig::default();
        let _: Box<dyn Send> = Box::new(TextGenerationModel::new(config));
    }
}