use std::fmt::{Debug, Formatter};
use tracing::trace;
use crate::chains::Message;
use crate::models::{ChatInput, Role};
use crate::SapiensConfig;
/// Renders a single [`ChatEntry`] as a `String`.
///
/// Implementors are passed to [`ChatHistory::format`], which calls
/// [`format`](ChatEntryFormatter::format) once per entry in the history.
pub trait ChatEntryFormatter {
/// Returns the textual representation of `entry`.
fn format(&self, entry: &ChatEntry) -> String;
}
/// Errors reported by [`ChatHistory`] operations.
#[derive(thiserror::Error, Debug)]
pub enum Error {
/// Returned by [`ChatHistory::purge`] when the input still exceeds the
/// token budget after every example and all but the last chitchat entry
/// have been dropped.
#[error("The prompt is too long")]
PromptTooLong,
}
/// A single entry of a chat: a message together with the role that wrote it.
#[derive(Clone)]
pub struct ChatEntry {
/// Role the message is attributed to.
pub role: Role,
/// Text of the message.
pub msg: String,
}
impl Debug for ChatEntry {
    /// Writes the entry as `[<role>]: <msg>`, using the `Display` form of the
    /// role followed by the raw message text.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self { role, msg } = self;
        f.write_fmt(format_args!("[{}]: {}", role, msg))
    }
}
/// The state of a conversation: fixed context, example exchanges, and the
/// running chat, together with the token limits used by
/// [`ChatHistory::purge`].
#[derive(Clone)]
pub struct ChatHistory {
// Supplies the model used to count tokens and `min_tokens_for_completion`.
config: SapiensConfig,
// Upper bound on tokens that `purge` measures the input against.
max_token: usize,
// Entries installed by `set_context`; `purge` never removes them.
context: Vec<ChatEntry>,
// (user, assistant) pairs appended by `add_example`.
examples: Vec<(ChatEntry, ChatEntry)>,
// Entries appended by `add_chitchat`.
chitchat: Vec<ChatEntry>,
}
impl Debug for ChatHistory {
    /// Formats the history as a `ChatHistory { .. }` debug struct listing
    /// every field in declaration order.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self {
            config,
            max_token,
            context,
            examples,
            chitchat,
        } = self;

        let mut out = f.debug_struct("ChatHistory");
        out.field("config", config);
        out.field("max_token", max_token);
        out.field("context", context);
        out.field("examples", examples);
        out.field("chitchat", chitchat);
        out.finish()
    }
}
impl ChatHistory {
    /// Creates a history with no context, examples, or chitchat.
    ///
    /// `max_token` is the overall token limit; [`purge`](Self::purge) keeps
    /// the input within `max_token` minus `config.min_tokens_for_completion`.
    pub fn new(config: SapiensConfig, max_token: usize) -> Self {
        Self {
            config,
            max_token,
            context: Vec::new(),
            examples: Vec::new(),
            chitchat: Vec::new(),
        }
    }

    /// Replaces the whole context with `context`.
    pub fn set_context(&mut self, context: Vec<ChatEntry>) {
        self.context = context;
    }

    /// Appends an example exchange: `user` is stored with [`Role::User`] and
    /// `bot` with [`Role::Assistant`].
    pub async fn add_example(&mut self, user: String, bot: String) {
        let msg_user = ChatEntry {
            role: Role::User,
            msg: user,
        };
        let msg_bot = ChatEntry {
            role: Role::Assistant,
            msg: bot,
        };
        self.examples.push((msg_user, msg_bot));
    }

    /// Appends `entry` to the chitchat.
    ///
    /// If the most recent chitchat entry has the same role as `entry`, it is
    /// replaced instead, so two consecutive entries never share a role.
    pub async fn add_chitchat(&mut self, entry: ChatEntry) {
        if let Some(last) = self.chitchat.last_mut() {
            if last.role == entry.role {
                // Same speaker twice in a row: overwrite rather than append.
                *last = entry;
                return;
            }
        }
        self.chitchat.push(entry);
    }

    /// Builds a [`ChatInput`] from clones of the context, examples, and
    /// chitchat.
    pub(crate) fn make_input(&self) -> ChatInput {
        ChatInput {
            context: self.context.clone(),
            examples: self.examples.clone(),
            chat: self.chitchat.clone(),
        }
    }

    /// Returns `true` when there are no chitchat entries.
    pub(crate) fn is_chitchat_empty(&self) -> bool {
        self.chitchat.is_empty()
    }

    /// Drops old entries until the input fits the token budget.
    ///
    /// The budget is `max_token` minus `config.min_tokens_for_completion`,
    /// saturating at zero. Token counts come from `config.model.num_tokens`
    /// applied to [`make_input`](Self::make_input). Entries are removed in
    /// this order, re-measuring after each removal:
    ///
    /// 1. examples, oldest first, until the input fits or none remain;
    /// 2. chitchat entries, oldest first, until the input fits or only one
    ///    remains.
    ///
    /// The context is never touched. If the chitchat is empty, nothing is
    /// measured or removed and `Ok(0)` is returned.
    ///
    /// Returns the number of chitchat entries remaining.
    ///
    /// # Errors
    ///
    /// Returns [`Error::PromptTooLong`] when the input still exceeds the
    /// budget with no examples and a single chitchat entry left.
    pub async fn purge(&mut self) -> Result<usize, Error> {
        if self.chitchat.is_empty() {
            return Ok(0);
        }

        // A plain `-` would panic (debug) or wrap around to a huge budget
        // (release) whenever the completion reserve exceeds `max_token`;
        // saturating keeps the budget at zero in that case.
        let budget = self
            .max_token
            .saturating_sub(self.config.min_tokens_for_completion);

        trace!(
            max_token = self.max_token,
            min_tokens_for_completion = self.config.min_tokens_for_completion,
            "purging history"
        );

        // Phase 1: sacrifice examples before touching the conversation.
        while !self.examples.is_empty() {
            let input = self.make_input();
            let num_tokens = self.config.model.num_tokens(input).await;
            trace!(
                max_token = self.max_token,
                min_tokens_for_completion = self.config.min_tokens_for_completion,
                len = self.examples.len(),
                num_tokens,
                "purging history - examples"
            );
            if num_tokens <= budget {
                return Ok(self.chitchat.len());
            }
            self.examples.remove(0);
        }

        // Phase 2: drop the oldest chitchat, always keeping the latest entry.
        while self.chitchat.len() > 1 {
            let input = self.make_input();
            let num_tokens = self.config.model.num_tokens(input).await;
            trace!(
                max_token = self.max_token,
                min_tokens_for_completion = self.config.min_tokens_for_completion,
                len = self.chitchat.len(),
                num_tokens,
                "purging history - loop"
            );
            if num_tokens <= budget {
                return Ok(self.chitchat.len());
            }
            self.chitchat.remove(0);
        }

        // Final check with only the context and a single chitchat entry left.
        let input = self.make_input();
        let num_tokens = self.config.model.num_tokens(input).await;
        if num_tokens <= budget {
            return Ok(self.chitchat.len());
        }
        Err(Error::PromptTooLong)
    }

    /// Iterates over every entry: the context first, then each example pair
    /// (user entry before assistant entry), then the chitchat.
    pub fn iter(&self) -> impl Iterator<Item = &ChatEntry> {
        self.context
            .iter()
            // A by-value array yields both halves without a heap allocation
            // per example pair.
            .chain(self.examples.iter().flat_map(|(user, bot)| [user, bot]))
            .chain(self.chitchat.iter())
    }

    /// Formats every entry yielded by [`iter`](Self::iter) with `formatter`,
    /// preserving order.
    pub fn format<T>(&self, formatter: &T) -> Vec<String>
    where
        T: ChatEntryFormatter + ?Sized,
    {
        self.iter().map(|entry| formatter.format(entry)).collect()
    }
}
impl From<&ChatHistory> for Vec<ChatEntry> {
    /// Clones every entry yielded by the history's `iter()` into a new
    /// vector, preserving iteration order.
    fn from(val: &ChatHistory) -> Self {
        let mut entries = Vec::new();
        entries.extend(val.iter().cloned());
        entries
    }
}
/// A list of [`Message`]s that can be rendered to strings with a
/// [`MessageFormatter`].
pub struct ContextDump {
/// The messages, in the order they will be formatted.
pub messages: Vec<Message>,
}
/// Renders a single [`Message`] as a `String`.
///
/// Implementors are passed to [`ContextDump::format`], which calls
/// [`format`](MessageFormatter::format) once per message.
pub trait MessageFormatter {
/// Returns the textual representation of `msg`.
fn format(&self, msg: &Message) -> String;
}
impl ContextDump {
    /// Formats each message with `formatter` and returns the resulting
    /// strings in the same order as `messages`.
    pub fn format<T>(&self, formatter: &T) -> Vec<String>
    where
        T: MessageFormatter + ?Sized,
    {
        let mut rendered = Vec::with_capacity(self.messages.len());
        for message in &self.messages {
            rendered.push(formatter.format(message));
        }
        rendered
    }
}