use crate::common::error::RustBertError;
use crate::gpt2::GPT2Generator;
use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use crate::resources::ResourceProvider;
use std::collections::HashMap;
use tch::{Device, Kind, Tensor};
use uuid::Uuid;
#[cfg(feature = "remote")]
use crate::{
gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
resources::RemoteResource,
};
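/// # Configuration for multi-turn conversation generation
///
/// Bundles the resources needed to load the model (weights, configuration, vocabulary and
/// optional merges files), the text generation options (sampling, beam search, penalties and
/// length constraints) and the device/precision to run on. With the `remote` feature enabled,
/// `Default` points to a remotely hosted DialoGPT-medium model.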
pub struct ConversationConfig {
pub model_type: ModelType,
pub model_resource: ModelResource,
pub config_resource: Box<dyn ResourceProvider + Send>,
pub vocab_resource: Box<dyn ResourceProvider + Send>,
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
pub min_length: i64,
pub max_length: Option<i64>,
pub min_length_for_response: i64,
pub do_sample: bool,
pub early_stopping: bool,
pub num_beams: i64,
pub temperature: f64,
pub top_k: i64,
pub top_p: f64,
pub repetition_penalty: f64,
pub length_penalty: f64,
pub no_repeat_ngram_size: i64,
pub num_return_sequences: i64,
pub num_beam_groups: Option<i64>,
pub diversity_penalty: Option<f64>,
pub device: Device,
pub kind: Option<Kind>,
}
#[cfg(feature = "remote")]
impl Default for ConversationConfig {
fn default() -> ConversationConfig {
ConversationConfig {
model_type: ModelType::GPT2,
model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
Gpt2ModelResources::DIALOGPT_MEDIUM,
))),
config_resource: Box::new(RemoteResource::from_pretrained(
Gpt2ConfigResources::DIALOGPT_MEDIUM,
)),
vocab_resource: Box::new(RemoteResource::from_pretrained(
Gpt2VocabResources::DIALOGPT_MEDIUM,
)),
merges_resource: Some(Box::new(RemoteResource::from_pretrained(
Gpt2MergesResources::DIALOGPT_MEDIUM,
))),
min_length: 0,
max_length: Some(1000),
min_length_for_response: 64,
do_sample: true,
early_stopping: false,
num_beams: 1,
temperature: 1.0,
top_k: 50,
top_p: 0.9,
repetition_penalty: 1.0,
length_penalty: 1.0,
no_repeat_ngram_size: 0,
num_return_sequences: 1,
num_beam_groups: None,
diversity_penalty: None,
device: Device::cuda_if_available(),
kind: None,
}
}
}
impl From<ConversationConfig> for GenerateConfig {
fn from(config: ConversationConfig) -> GenerateConfig {
GenerateConfig {
model_type: config.model_type,
model_resource: config.model_resource,
config_resource: config.config_resource,
merges_resource: config.merges_resource,
vocab_resource: config.vocab_resource,
min_length: config.min_length,
max_length: config.max_length,
do_sample: config.do_sample,
early_stopping: config.early_stopping,
num_beams: config.num_beams,
temperature: config.temperature,
top_k: config.top_k,
top_p: config.top_p,
repetition_penalty: config.repetition_penalty,
length_penalty: config.length_penalty,
no_repeat_ngram_size: config.no_repeat_ngram_size,
num_return_sequences: config.num_return_sequences,
num_beam_groups: config.num_beam_groups,
diversity_penalty: config.diversity_penalty,
device: config.device,
kind: config.kind,
}
}
}
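/// State of a single conversation: past user inputs, generated responses, the pending
/// (not yet processed) user input, and the token id history of each conversation turn.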
#[derive(Debug, Clone)]
pub struct Conversation {
pub past_user_inputs: Vec<String>,
pub generated_responses: Vec<String>,
pub new_user_input: Option<String>,
pub history: Vec<Vec<i64>>,
}
impl Conversation {
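    /// Builds a new `Conversation` holding an initial, not yet processed user input.
    ///
    /// A minimal sketch, assuming this module is exposed as `rust_bert::pipelines::conversation`:
    ///
    /// ```no_run
    /// use rust_bert::pipelines::conversation::Conversation;
    ///
    /// let conversation = Conversation::new("Hi there!");
    /// ```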
pub fn new(text: &str) -> Conversation {
Conversation {
past_user_inputs: vec![],
generated_responses: vec![],
new_user_input: Some(text.to_string()),
history: vec![],
}
}
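    /// Builds an empty `Conversation` with no pending user input.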
pub fn new_empty() -> Conversation {
Conversation {
past_user_inputs: vec![],
generated_responses: vec![],
new_user_input: None,
history: vec![],
}
}
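    /// Registers a new user input, returning an error if an unprocessed input is already
    /// pending for this conversation.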
pub fn add_user_input(&mut self, text: &str) -> Result<(), RustBertError> {
if self.new_user_input.is_some() {
Err(RustBertError::ValueError(
"User input already provided for this conversation".into(),
))
} else {
self.new_user_input = Some(text.to_string());
Ok(())
}
}
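    /// Registers a new user input, overwriting any pending input and returning the previous
    /// pending input if one was present.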
    pub fn add_user_input_with_overwrite(&mut self, text: &str) -> Option<String> {
        // `Option::replace` stores the new input and hands back the previous pending input, if any
        self.new_user_input.replace(text.to_string())
    }
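    /// Returns `true` if the conversation holds a pending user input awaiting a response.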
pub fn contains_new_input(&self) -> bool {
self.new_user_input.is_some()
}
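    /// Moves the pending user input, if any, to the list of past user inputs.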
    pub fn mark_processed(&mut self) {
        if let Some(input) = self.new_user_input.take() {
            self.past_user_inputs.push(input);
        }
    }
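    /// Returns the most recent user input: the pending input if present, otherwise the last
    /// past user input.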
    pub fn get_last_input(&self) -> Option<&str> {
        self.new_user_input
            .as_deref()
            .or_else(|| self.past_user_inputs.last().map(String::as_str))
    }
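    /// Returns the last generated response, if any.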
    pub fn get_last_response(&self) -> Option<&str> {
        self.generated_responses.last().map(String::as_str)
    }
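    /// Appends one conversation round, alternating between user inputs and generated
    /// responses, and records the corresponding token ids in the history.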
fn append(&mut self, text: &str, ids: &[i64]) {
match &self.new_user_input {
Some(_) => {
self.mark_processed();
if self.past_user_inputs.len() >= self.generated_responses.len() {
self.generated_responses.push(text.to_string());
} else {
let _ = self.add_user_input(text);
}
}
None => {
let _ = self.add_user_input(text);
}
}
self.history.push(ids.to_vec());
}
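    /// Initializes the conversation from a prior exchange of texts and their token ids,
    /// alternating between user inputs and model responses (starting with a user input).
    /// If the number of texts is odd, the trailing text is kept as the pending user input
    /// and its ids are removed from the history so they are not encoded twice at generation.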
pub fn load_from_history<S, SI>(&mut self, texts: &[S], ids: &[SI])
where
S: AsRef<str>,
SI: AsRef<[i64]>,
{
for (round_text, round_ids) in texts.iter().zip(ids.iter()) {
self.append(round_text.as_ref(), round_ids.as_ref());
}
        // An odd number of turns means the last text is a pending user input: drop its ids
        // from the history so they are not encoded a second time at generation.
        if texts.len() % 2 == 1 {
            self.history.pop();
        }
}
}
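/// Container holding a set of `Conversation`s keyed by `Uuid`, allowing responses to be
/// generated for several conversations in a single batch.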
#[derive(Debug)]
pub struct ConversationManager {
conversations: HashMap<Uuid, Conversation>,
}
impl ConversationManager {
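    /// Creates a new, empty `ConversationManager`.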
pub fn new() -> ConversationManager {
ConversationManager {
conversations: HashMap::new(),
}
}
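    /// Returns the UUIDs and mutable references of all conversations with a pending user
    /// input, i.e. those that require a generated response.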
pub fn get_active_conversations(&mut self) -> (Vec<&Uuid>, Vec<&mut Conversation>) {
let mut active_uuid = vec![];
let mut active_conversations = vec![];
for (uuid, conversation) in self.conversations.iter_mut() {
if conversation.new_user_input.is_some() {
active_uuid.push(uuid);
active_conversations.push(conversation)
}
}
(active_uuid, active_conversations)
}
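    /// Returns a mutable reference to the conversation registered under the given UUID, if any.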
pub fn get(&mut self, uuid: &Uuid) -> Option<&mut Conversation> {
self.conversations.get_mut(uuid)
}
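    /// Returns references to all registered conversations, keyed by their UUID.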
pub fn get_all(&mut self) -> HashMap<&Uuid, &Conversation> {
let mut output = HashMap::with_capacity(self.conversations.len());
for (uuid, conversation) in self.conversations.iter() {
output.insert(uuid, conversation);
}
output
}
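    /// Creates a conversation from an initial user input, registers it and returns its UUID.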
pub fn create(&mut self, text: &str) -> Uuid {
let conversation = Conversation::new(text);
self.add(conversation)
}
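    /// Creates and registers an empty conversation, returning its UUID.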
pub fn create_empty(&mut self) -> Uuid {
let conversation = Conversation::new_empty();
self.add(conversation)
}
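    /// Registers an existing conversation under a newly generated, collision-free UUID.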
pub fn add(&mut self, conversation: Conversation) -> Uuid {
let mut uuid = Uuid::new_v4();
while self.conversations.contains_key(&uuid) {
uuid = Uuid::new_v4();
}
self.conversations.insert(uuid, conversation);
uuid
}
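    /// Removes the conversation registered under the given UUID and returns it, if any.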
pub fn remove(&mut self, uuid: &Uuid) -> Option<Conversation> {
self.conversations.remove(uuid)
}
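    /// Removes all conversations from the manager and returns them, keyed by UUID.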
    pub fn clear(&mut self) -> HashMap<Uuid, Conversation> {
        // `mem::take` empties the manager and hands back the stored conversations without cloning
        std::mem::take(&mut self.conversations)
    }
}
impl Default for ConversationManager {
fn default() -> Self {
Self::new()
}
}
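/// Abstraction over the language models supported by the conversation pipeline
/// (currently GPT2-based models such as DialoGPT).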
pub enum ConversationOption {
GPT2(GPT2Generator),
}
impl ConversationOption {
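    /// Instantiates the generator described by the configuration, returning an error for
    /// unsupported model types.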
pub fn new(config: ConversationConfig) -> Result<Self, RustBertError> {
match config.model_type {
ModelType::GPT2 => Ok(ConversationOption::GPT2(GPT2Generator::new(config.into())?)),
_ => Err(RustBertError::InvalidConfigurationError(
"GPT2 is currently the only supported model for conversation generation"
.to_string(),
)),
}
}
pub fn new_with_tokenizer(
config: ConversationConfig,
tokenizer: TokenizerOption,
) -> Result<Self, RustBertError> {
match config.model_type {
ModelType::GPT2 => Ok(ConversationOption::GPT2(GPT2Generator::new_with_tokenizer(
config.into(),
tokenizer,
)?)),
_ => Err(RustBertError::InvalidConfigurationError(
"GPT2 is currently the only supported model for conversation generation"
.to_string(),
)),
}
}
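    /// Returns the end-of-sequence token id of the underlying model, or an error if the
    /// model does not define one.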
    pub fn get_eos_id(&self) -> Result<i64, RustBertError> {
        match self {
            Self::GPT2(model_ref) => model_ref
                .get_eos_ids()
                .as_ref()
                .and_then(|eos_ids| eos_ids.first().copied())
                .ok_or_else(|| {
                    RustBertError::InvalidConfigurationError(
                        "The conversation model must define an EOS token id".to_string(),
                    )
                }),
        }
    }
pub fn get_tokenizer(&self) -> &TokenizerOption {
match self {
Self::GPT2(model_ref) => model_ref._get_tokenizer(),
}
}
    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        match self {
            Self::GPT2(model_ref) => model_ref._get_tokenizer_mut(),
        }
    }
pub fn model_type(&self) -> ModelType {
match *self {
Self::GPT2(_) => ModelType::GPT2,
}
}
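    /// Runs generation from pre-computed input ids and an optional attention mask,
    /// returning the generated token id sequences.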
pub fn generate_from_ids_and_past(
&self,
input_ids: Tensor,
attention_mask: Option<Tensor>,
) -> Result<Vec<Vec<i64>>, RustBertError> {
Ok(match *self {
Self::GPT2(ref model) => model
.generate_from_ids_and_past(input_ids, attention_mask, None)?
.into_iter()
.map(|output| output.indices)
.collect(),
})
}
}
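/// # Multi-turn conversation model
///
/// Wraps the underlying generator together with its end-of-sequence token id, the maximum
/// allowed context length and the target device, and exposes batched response generation
/// over a `ConversationManager`.
///
/// A minimal usage sketch (assuming the `remote` feature is enabled and that this module is
/// exposed as `rust_bert::pipelines::conversation`):
///
/// ```no_run
/// use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
///
/// let conversation_model =
///     ConversationModel::new(Default::default()).expect("model loading failed");
/// let mut conversation_manager = ConversationManager::new();
///
/// let _conversation_id =
///     conversation_manager.create("Going to the movies tonight - any suggestions?");
/// let output = conversation_model
///     .generate_responses(&mut conversation_manager)
///     .expect("response generation failed");
/// println!("{output:?}");
/// ```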
pub struct ConversationModel {
model: ConversationOption,
eos_token_id: i64,
max_allowed_context_length: Option<i64>,
device: Device,
}
impl ConversationModel {
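    /// Builds a `ConversationModel` from a `ConversationConfig`, loading the underlying
    /// generator and reserving `min_length_for_response` tokens of the maximum length
    /// for the generated response.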
pub fn new(
conversation_config: ConversationConfig,
) -> Result<ConversationModel, RustBertError> {
let max_allowed_length = conversation_config
.max_length
.map(|max_length| max_length - conversation_config.min_length_for_response);
let device = conversation_config.device;
let model = ConversationOption::new(conversation_config)?;
let eos_token_id = model.get_eos_id()?;
Ok(ConversationModel {
model,
eos_token_id,
max_allowed_context_length: max_allowed_length,
device,
})
}
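    /// Same as `new`, but uses the provided `TokenizerOption` instead of loading the
    /// tokenizer from the configured vocabulary resources.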
pub fn new_with_tokenizer(
conversation_config: ConversationConfig,
tokenizer: TokenizerOption,
) -> Result<ConversationModel, RustBertError> {
let max_allowed_length = conversation_config
.max_length
.map(|max_length| max_length - conversation_config.min_length_for_response);
let device = conversation_config.device;
let model = ConversationOption::new_with_tokenizer(conversation_config, tokenizer)?;
let eos_token_id = model.get_eos_id()?;
Ok(ConversationModel {
model,
eos_token_id,
max_allowed_context_length: max_allowed_length,
device,
})
}
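    /// Generates responses for all conversations in the manager that contain a pending
    /// user input. Pending inputs are tokenized, concatenated with each conversation's
    /// token id history, and processed by the model as a single padded batch. The
    /// conversations are updated in place (history, responses, processed inputs) and the
    /// generated replies are returned keyed by conversation UUID.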
pub fn generate_responses<'a>(
&self,
conversation_manager: &'a mut ConversationManager,
) -> Result<HashMap<&'a Uuid, &'a str>, RustBertError> {
let (active_uuid, active_conversations) = conversation_manager.get_active_conversations();
let updated_conversations = if !active_uuid.is_empty() {
let texts = active_conversations
.iter()
.map(|c| c.new_user_input.as_ref().unwrap().as_str())
.collect::<Vec<&str>>();
let history = active_conversations
.iter()
.map(|c| c.history.iter().flatten().copied().collect())
.collect::<Vec<Vec<i64>>>();
let prompt_ids = self.encode_prompts(texts.as_ref());
let (input_tensor, attention_mask) =
self.concat_input_history(prompt_ids.as_ref(), history);
let input_length = *input_tensor.size().last().unwrap() as usize;
let mut generated = self
.model
.generate_from_ids_and_past(input_tensor, Some(attention_mask))?;
let removed_padding_quantities = self.clean_padding_indices(&mut generated);
let mut output = HashMap::with_capacity(active_uuid.len());
for (
                ((conversation, (generated_sequence, conversation_prompt_ids)), uuid),
removed_padding,
) in active_conversations
.into_iter()
.zip(generated.into_iter().zip(prompt_ids.into_iter()))
.zip(active_uuid.into_iter())
.zip(removed_padding_quantities.into_iter())
{
let generated_response = &generated_sequence[input_length - removed_padding.0..];
conversation
.generated_responses
.push(
self.model
.get_tokenizer()
.decode(generated_response, true, true),
);
                conversation.history.push(conversation_prompt_ids);
conversation.history.push(generated_response.to_vec());
conversation.mark_processed();
output.insert(uuid, conversation.get_last_response().unwrap());
}
output
} else {
HashMap::new()
};
Ok(updated_conversations)
}
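    /// Strips the padding inserted for batched generation from each output sequence
    /// (keeping a single trailing padding/end-of-sequence token) and returns, for each
    /// sequence, the number of padding tokens found at its start and end.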
fn clean_padding_indices(&self, model_output: &mut Vec<Vec<i64>>) -> Vec<(usize, usize)> {
let pad_token = self
.model
.get_tokenizer()
.get_pad_id()
.unwrap_or(self.eos_token_id);
let mut removed_tokens = Vec::with_capacity(model_output.len());
for sequence_history in model_output {
let index_end = sequence_history
.iter()
.rev()
.position(|&r| r != pad_token)
.unwrap();
let index_start = sequence_history
.iter()
.position(|&r| r != pad_token)
.unwrap();
if index_end > 0 {
sequence_history.drain(sequence_history.len() - index_end + 1..);
}
sequence_history.drain(..index_start);
removed_tokens.push((index_start, index_end));
}
removed_tokens
}
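    /// Concatenates each conversation history with the corresponding new prompt ids,
    /// truncates over-long inputs to the maximum allowed context length, left-pads the
    /// batch to a common length and returns the input tensor with its attention mask.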
fn concat_input_history(
&self,
inputs: &[Vec<i64>],
history: Vec<Vec<i64>>,
) -> (Tensor, Tensor) {
let pad_token = self
.model
.get_tokenizer()
.get_pad_id()
.unwrap_or(self.eos_token_id);
assert_eq!(
inputs.len(),
history.len(),
"Length of inputs should equal length of history"
);
let mut concatenated_inputs = Vec::with_capacity(inputs.len());
for (input, history) in inputs.iter().zip(history.iter()) {
let mut concatenated_element = Vec::with_capacity(input.len() + history.len());
concatenated_element.extend_from_slice(history);
concatenated_element.extend_from_slice(input);
concatenated_inputs.push(concatenated_element);
}
let truncated_concatenated_inputs = concatenated_inputs
.iter()
.map(|input| match self.max_allowed_context_length {
Some(max_allowed_context_length)
if input.len() > max_allowed_context_length as usize =>
{
let start = self.get_truncated_input_index(
input,
max_allowed_context_length as usize,
pad_token,
);
&input[start..]
}
_ => input.as_slice(),
})
.collect::<Vec<&[i64]>>();
let max_len = truncated_concatenated_inputs
.iter()
.map(|input| input.len())
.max()
.unwrap();
let attention_mask = Tensor::ones(
[inputs.len() as i64, max_len as i64],
(Kind::Int8, self.device),
);
let concatenated_inputs = truncated_concatenated_inputs
.into_iter()
.enumerate()
.map(|(input_idx, input)| {
let _ = attention_mask
.get(input_idx as i64)
.slice(0, 0, (max_len - input.len()) as i64, 1)
.fill_(0);
let mut padded_input = vec![pad_token; max_len - input.len()];
padded_input.extend(input);
padded_input
})
.map(|tokens| Tensor::from_slice(&tokens).to(self.device))
.collect::<Vec<Tensor>>();
(Tensor::stack(&concatenated_inputs, 0), attention_mask)
}
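    /// Returns the index at which an over-long input should be truncated, preferring a cut
    /// right after a padding/end-of-sequence token inside the allowed window and falling
    /// back to a hard cut at `history.len() - max_length`.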
fn get_truncated_input_index(
&self,
history: &[i64],
max_length: usize,
pad_token: i64,
) -> usize {
let start_length = history.len();
let eos_indices: Vec<usize> = history
.iter()
.enumerate()
.filter(|(i, &e)| {
(e == pad_token)
& (*i != start_length - 1)
& ((start_length as isize - max_length as isize - *i as isize) < 0)
})
.map(|(i, _)| i + 1)
.collect();
*eos_indices.first().unwrap_or(&(start_length - max_length))
}
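    /// Tokenizes the given prompts and converts them to token ids, truncating each prompt
    /// to fit the maximum allowed context length and appending the end-of-sequence token.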
pub fn encode_prompts(&self, texts: &[&str]) -> Vec<Vec<i64>> {
let tokens = self.model.get_tokenizer().tokenize_list(texts);
tokens
.into_iter()
.map(|prompt_tokens| {
self.model
.get_tokenizer()
.convert_tokens_to_ids(&prompt_tokens)
})
.map(|mut tokens| {
if let Some(max_allowed_context_length) = self.max_allowed_context_length {
tokens.truncate(max_allowed_context_length as usize - 1);
}
tokens.push(self.eos_token_id);
tokens
})
.collect::<Vec<Vec<i64>>>()
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
    #[ignore] // compilation alone is enough to verify that the model is `Send`
    fn test() {
let config = ConversationConfig::default();
let _: Box<dyn Send> = Box::new(ConversationModel::new(config));
}
}