use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::fmt;
use crate::tokens::{Tokenizer, TokenizerError};
use super::{StringTemplate, StringTemplateError};
use crate::Parameters;
/// The role a [`ChatMessage`] is attributed to.
///
/// Three roles have dedicated variants; any other role is carried by name in
/// [`ChatRole::Other`], whose string is displayed verbatim.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ChatRole {
    /// Messages attributed to the user.
    User,
    /// Messages attributed to the assistant.
    Assistant,
    /// Messages attributed to the system.
    System,
    /// Any other role, identified by its name.
    Other(String),
}
impl fmt::Display for ChatRole {
    /// Writes the role label: the variant name for the built-in roles, or the
    /// wrapped string unchanged for [`ChatRole::Other`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            ChatRole::User => "User",
            ChatRole::Assistant => "Assistant",
            ChatRole::System => "System",
            ChatRole::Other(custom) => custom.as_str(),
        };
        f.write_str(label)
    }
}
/// A single chat message: the [`ChatRole`] it is attributed to and its payload.
///
/// `Body` is generic; this module uses it with plain `String`s and with
/// `StringTemplate`s. Use [`ChatMessage::map`] / [`ChatMessage::try_map`] to
/// convert between body types while keeping the role.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatMessage<Body> {
    /// Who the message is attributed to.
    role: ChatRole,
    /// The message payload.
    body: Body,
}
impl<Body> ChatMessage<Body> {
    /// Creates a message attributed to `role` carrying `body`.
    pub fn new(role: ChatRole, body: Body) -> Self {
        Self { role, body }
    }

    /// Creates a message with [`ChatRole::Assistant`].
    pub fn assistant(body: Body) -> Self {
        Self::new(ChatRole::Assistant, body)
    }

    /// Creates a message with [`ChatRole::User`].
    pub fn user(body: Body) -> Self {
        Self::new(ChatRole::User, body)
    }

    /// Creates a message with [`ChatRole::System`].
    pub fn system(body: Body) -> Self {
        Self::new(ChatRole::System, body)
    }

    /// Returns a new message with the same role and the body produced by `f`.
    ///
    /// `self` is left untouched; the role is cloned into the result.
    pub fn map<U, F: FnOnce(&Body) -> U>(&self, f: F) -> ChatMessage<U> {
        let role = self.role.clone();
        ChatMessage {
            role,
            body: f(&self.body),
        }
    }

    /// Fallible counterpart of [`ChatMessage::map`].
    ///
    /// `f` is called exactly once, so any `FnOnce` closure is accepted (this is
    /// a superset of the previous `Fn` bound, so existing callers still compile).
    ///
    /// # Errors
    ///
    /// Returns the error produced by `f` unchanged.
    pub fn try_map<U, E, F: FnOnce(&Body) -> Result<U, E>>(
        &self,
        f: F,
    ) -> Result<ChatMessage<U>, E> {
        let body = f(&self.body)?;
        Ok(ChatMessage {
            role: self.role.clone(),
            body,
        })
    }

    /// The role this message is attributed to.
    pub fn role(&self) -> &ChatRole {
        &self.role
    }

    /// The message payload.
    pub fn body(&self) -> &Body {
        &self.body
    }
}
impl<T: fmt::Display> fmt::Display for ChatMessage<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}: {}", self.role, self.body)
}
}
/// An ordered sequence of [`ChatMessage`]s.
///
/// Messages are kept in insertion order: new messages are appended at the back
/// and the oldest message sits at the front.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatMessageCollection<Body> {
    /// Backing queue; a `VecDeque` so both ends can be trimmed cheaply.
    messages: VecDeque<ChatMessage<Body>>,
}
impl<Body> ChatMessageCollection<Body> {
pub fn new() -> Self {
ChatMessageCollection {
messages: VecDeque::new(),
}
}
pub fn for_vector(messages: Vec<ChatMessage<Body>>) -> Self {
ChatMessageCollection {
messages: messages.into(),
}
}
pub fn with_system(mut self, body: Body) -> Self {
self.add_message(ChatMessage::system(body));
self
}
pub fn with_user(mut self, body: Body) -> Self {
self.add_message(ChatMessage::user(body));
self
}
pub fn with_assistant(mut self, body: Body) -> Self {
self.add_message(ChatMessage::assistant(body));
self
}
pub fn append(&mut self, other: ChatMessageCollection<Body>) {
self.messages.extend(other.messages);
}
pub fn add_message(&mut self, message: ChatMessage<Body>) {
self.messages.push_back(message);
}
pub fn remove_first_message(&mut self) -> Option<ChatMessage<Body>> {
self.messages.pop_front()
}
pub fn len(&self) -> usize {
self.messages.len()
}
pub(crate) fn extract_last_body(&self) -> Option<&Body> {
self.messages.back().map(|x| &x.body)
}
pub fn is_empty(&self) -> bool {
self.messages.is_empty()
}
pub fn get_message(&self, index: usize) -> Option<&ChatMessage<Body>> {
self.messages.get(index)
}
pub fn iter(&self) -> std::collections::vec_deque::Iter<'_, ChatMessage<Body>> {
self.messages.iter()
}
pub fn map<U, F>(&self, f: F) -> ChatMessageCollection<U>
where
F: FnMut(&ChatMessage<Body>) -> ChatMessage<U>,
{
let mapped_messages: VecDeque<ChatMessage<U>> = self.messages.iter().map(f).collect();
ChatMessageCollection {
messages: mapped_messages,
}
}
pub fn try_map<U, E, F: Fn(&Body) -> Result<U, E>>(
&self,
f: F,
) -> Result<ChatMessageCollection<U>, E> {
let mut mapped_messages = VecDeque::new();
for msg in self.messages.iter() {
let mapped_msg = msg.try_map(|body| f(body))?;
mapped_messages.push_back(mapped_msg);
}
Ok(ChatMessageCollection {
messages: mapped_messages,
})
}
pub fn trim_to_max_messages(&mut self, max_number_of_messages: usize) {
while self.len() > max_number_of_messages {
self.messages.pop_front();
}
}
}
impl<Body> Default for ChatMessageCollection<Body> {
fn default() -> Self {
ChatMessageCollection::new()
}
}
impl<T: fmt::Display> fmt::Display for ChatMessageCollection<T> {
    /// Writes every message, oldest first, each followed by a newline.
    /// An empty collection renders as the empty string.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.messages
            .iter()
            .try_for_each(|message| writeln!(f, "{}", message))
    }
}
impl ChatMessageCollection<String> {
    /// Trims the conversation so its total token count fits within `max_tokens`.
    ///
    /// Messages are dropped from the front (oldest first), like
    /// [`ChatMessageCollection::trim_to_max_messages`]. This keeps the longest run
    /// of most-recent messages whose combined token count is `<= max_tokens`.
    ///
    /// Each body is tokenized with `tokenizer`, walking from the newest message
    /// backwards. As soon as the running total exceeds `max_tokens`, that message
    /// and every older one are removed. A negative `max_tokens` removes every
    /// message.
    ///
    /// # Errors
    ///
    /// Returns the first [`TokenizerError`] reported by `tokenizer`. In that case
    /// the collection is left unmodified.
    pub fn trim_context<Tok>(
        &mut self,
        tokenizer: &Tok,
        max_tokens: i32,
    ) -> Result<(), TokenizerError>
    where
        Tok: Tokenizer,
    {
        if max_tokens < 0 {
            // No message can fit a negative budget.
            self.messages.clear();
            return Ok(());
        }
        // Non-negative here; saturate on targets where usize is narrower than i32.
        let budget = usize::try_from(max_tokens).unwrap_or(usize::MAX);

        // Count how many trailing (newest) messages fit in the budget.
        let mut used: usize = 0;
        let mut kept: usize = 0;
        for message in self.messages.iter().rev() {
            let count = tokenizer.tokenize_str(&message.body)?.len();
            used = used.saturating_add(count);
            if used > budget {
                break;
            }
            kept += 1;
        }

        // Mutate only after all tokenization succeeded; drop everything older.
        let excess = self.messages.len() - kept;
        self.messages.drain(..excess);
        Ok(())
    }

    /// Renders `body` as a Tera template with `parameters` and appends the result
    /// as a [`ChatRole::User`] message.
    ///
    /// # Errors
    ///
    /// Returns the [`StringTemplateError`] from rendering; `self` is consumed either way.
    pub fn with_user_template(
        self,
        body: &str,
        parameters: &Parameters,
    ) -> Result<Self, StringTemplateError> {
        StringTemplate::tera(body)
            .format(parameters)
            .map(|templated_body| self.with_user(templated_body))
    }

    /// Renders `body` as a Tera template with `parameters` and appends the result
    /// as a [`ChatRole::System`] message.
    ///
    /// # Errors
    ///
    /// Returns the [`StringTemplateError`] from rendering; `self` is consumed either way.
    pub fn with_system_template(
        self,
        body: &str,
        parameters: &Parameters,
    ) -> Result<Self, StringTemplateError> {
        StringTemplate::tera(body)
            .format(parameters)
            .map(|templated_body| self.with_system(templated_body))
    }

    /// Renders `body` as a Tera template with `parameters` and appends the result
    /// as a [`ChatRole::Assistant`] message.
    ///
    /// # Errors
    ///
    /// Returns the [`StringTemplateError`] from rendering; `self` is consumed either way.
    pub fn with_assistant_template(
        self,
        body: &str,
        parameters: &Parameters,
    ) -> Result<Self, StringTemplateError> {
        StringTemplate::tera(body)
            .format(parameters)
            .map(|templated_body| self.with_assistant(templated_body))
    }
}
impl ChatMessageCollection<StringTemplate> {
    /// Appends a [`ChatRole::User`] message whose body is `body` wrapped as a
    /// Tera [`StringTemplate`]; the template is not rendered here.
    pub fn with_user_template(self, body: &str) -> Self {
        let template = StringTemplate::tera(body);
        self.with_user(template)
    }

    /// Appends a [`ChatRole::System`] message whose body is `body` wrapped as a
    /// Tera [`StringTemplate`]; the template is not rendered here.
    pub fn with_system_template(self, body: &str) -> Self {
        let template = StringTemplate::tera(body);
        self.with_system(template)
    }

    /// Appends a [`ChatRole::Assistant`] message whose body is `body` wrapped as a
    /// Tera [`StringTemplate`]; the template is not rendered here.
    pub fn with_assistant_template(self, body: &str) -> Self {
        let template = StringTemplate::tera(body);
        self.with_assistant(template)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_map() {
        let msg = ChatMessage::new(ChatRole::Assistant, "Hello!");
        let mapped_msg = msg.map(|body| body.to_uppercase());
        assert_eq!(mapped_msg.body, "HELLO!");
        assert_eq!(mapped_msg.role, ChatRole::Assistant);
    }

    #[test]
    fn test_chat_message_list() {
        let mut chat_message_list = ChatMessageCollection::new();
        assert_eq!(chat_message_list.len(), 0);
        chat_message_list.add_message(ChatMessage::new(ChatRole::User, "Hello!"));
        chat_message_list.add_message(ChatMessage::new(ChatRole::Assistant, "Hi there!"));
        assert_eq!(chat_message_list.len(), 2);
        assert_eq!(chat_message_list.get_message(0).unwrap().body, "Hello!");
        assert_eq!(chat_message_list.get_message(1).unwrap().body, "Hi there!");
        chat_message_list.remove_first_message();
        assert_eq!(chat_message_list.len(), 1);
        // The oldest message is the one removed.
        assert_eq!(chat_message_list.get_message(0).unwrap().body, "Hi there!");
    }

    #[test]
    fn test_chat_message_list_map() {
        let mut chat_message_list = ChatMessageCollection::new();
        chat_message_list.add_message(ChatMessage::new(ChatRole::User, "Hello!"));
        chat_message_list.add_message(ChatMessage::new(ChatRole::Assistant, "Hi there!"));
        let mapped_list = chat_message_list
            .map(|msg| ChatMessage::new(msg.role.clone(), format!("{} (mapped)", msg.body)));
        assert_eq!(mapped_list.get_message(0).unwrap().body, "Hello! (mapped)");
        assert_eq!(
            mapped_list.get_message(1).unwrap().body,
            "Hi there! (mapped)"
        );
    }

    #[test]
    fn test_chat_role_display() {
        assert_eq!(ChatRole::User.to_string(), "User");
        assert_eq!(ChatRole::Assistant.to_string(), "Assistant");
        assert_eq!(ChatRole::System.to_string(), "System");
        assert_eq!(ChatRole::Other("Tool".to_string()).to_string(), "Tool");
    }

    #[test]
    fn test_collection_display() {
        let list = ChatMessageCollection::new()
            .with_system("Be brief.")
            .with_user("Hi");
        assert_eq!(list.to_string(), "System: Be brief.\nUser: Hi\n");
        assert_eq!(ChatMessageCollection::<&str>::new().to_string(), "");
    }

    #[test]
    fn test_try_map() {
        let bad = ChatMessageCollection::new()
            .with_user("1")
            .with_user("x")
            .with_user("3");
        assert!(bad.try_map(|body| body.parse::<i32>()).is_err());

        let good = ChatMessageCollection::new().with_user("1").with_assistant("2");
        let parsed = good.try_map(|body| body.parse::<i32>()).unwrap();
        assert_eq!(parsed.get_message(1).unwrap().body, 2);
        assert_eq!(parsed.get_message(1).unwrap().role, ChatRole::Assistant);
    }

    #[test]
    fn test_trim_to_max_messages_keeps_newest() {
        let mut list = ChatMessageCollection::new()
            .with_user("a")
            .with_assistant("b")
            .with_user("c");
        list.trim_to_max_messages(5);
        assert_eq!(list.len(), 3);
        list.trim_to_max_messages(2);
        assert_eq!(list.len(), 2);
        assert_eq!(list.get_message(0).unwrap().body, "b");
        list.trim_to_max_messages(0);
        assert!(list.is_empty());
    }

    #[test]
    fn test_append_and_last_body() {
        let mut list = ChatMessageCollection::new().with_user("a");
        list.append(ChatMessageCollection::new().with_assistant("b"));
        assert_eq!(list.len(), 2);
        assert_eq!(list.extract_last_body(), Some(&"b"));
    }
}