mod types;
pub use types::*;
use async_trait::async_trait;
use crate::error::{Result, RustAgentsError};
use crate::harness::message::Message;
use crate::harness::store::Store;
impl InMemoryChatHistory {
    /// Builds a new in-memory history.
    ///
    /// This is exactly equivalent to calling `InMemoryChatHistory::default()`.
    pub fn new() -> Self {
        Default::default()
    }
}
#[async_trait]
impl ChatHistory for InMemoryChatHistory {
async fn messages(&self, thread_id: &str) -> Result<Vec<Message>> {
let threads = self
.threads
.lock()
.map_err(|e| RustAgentsError::Memory(format!("chat history lock poisoned: {e}")))?;
Ok(threads.get(thread_id).cloned().unwrap_or_default())
}
async fn append(&self, thread_id: &str, message: Message) -> Result<()> {
let mut threads = self
.threads
.lock()
.map_err(|e| RustAgentsError::Memory(format!("chat history lock poisoned: {e}")))?;
threads
.entry(thread_id.to_string())
.or_default()
.push(message);
Ok(())
}
async fn clear(&self, thread_id: &str) -> Result<()> {
let mut threads = self
.threads
.lock()
.map_err(|e| RustAgentsError::Memory(format!("chat history lock poisoned: {e}")))?;
threads.remove(thread_id);
Ok(())
}
}
impl<S: Store> StoreChatHistory<S> {
pub const NAMESPACE: &'static str = "chat_history";
pub fn new(store: S) -> Self {
Self { store }
}
pub fn store(&self) -> &S {
&self.store
}
}
#[async_trait]
impl<S: Store> ChatHistory for StoreChatHistory<S> {
    /// Reads and decodes the JSON message list stored for `thread_id`.
    ///
    /// A missing key yields an empty vector.
    ///
    /// # Errors
    /// Propagates store errors. Also propagates JSON decoding errors, which
    /// are converted into the crate error type via `?`.
    async fn messages(&self, thread_id: &str) -> Result<Vec<Message>> {
        let stored = self.store.get(Self::NAMESPACE, thread_id).await?;
        let decoded = stored
            .map(serde_json::from_value::<Vec<Message>>)
            .transpose()?;
        Ok(decoded.unwrap_or_default())
    }

    /// Appends `message` by loading the whole list, pushing onto it, and
    /// writing the whole list back.
    ///
    /// The read (`get`) and the write (`put`) are two separate store calls.
    /// NOTE(review): unless the `Store` serializes them, concurrent appends
    /// to the same thread may overwrite each other.
    ///
    /// # Errors
    /// Propagates store errors and JSON (de)serialization errors.
    async fn append(&self, thread_id: &str, message: Message) -> Result<()> {
        let mut history = self.messages(thread_id).await?;
        history.push(message);
        self.store
            .put(Self::NAMESPACE, thread_id, serde_json::to_value(&history)?)
            .await
    }

    /// Deletes the stored message list for `thread_id`.
    ///
    /// # Errors
    /// Propagates any error returned by the store's `delete`.
    async fn clear(&self, thread_id: &str) -> Result<()> {
        let namespace = Self::NAMESPACE;
        self.store.delete(namespace, thread_id).await
    }
}
impl<H: ChatHistory> ShortTermMemory<H> {
    /// Creates a memory view over `thread_id` in `history`.
    ///
    /// No trim function is configured.
    #[must_use]
    pub fn new(history: H, thread_id: impl Into<String>) -> Self {
        Self {
            history,
            thread_id: thread_id.into(),
            trim: None,
        }
    }

    /// Installs `trim`, replacing any previously configured trim function.
    ///
    /// `trim` is applied to the message list by [`load`](Self::load) and
    /// [`save`](Self::save), but not by [`append`](Self::append).
    ///
    /// This consumes and returns `self` builder-style, so discarding the
    /// result also discards the configuration. Hence `#[must_use]`.
    #[must_use]
    pub fn with_trim(
        mut self,
        trim: impl Fn(Vec<Message>) -> Vec<Message> + Send + Sync + 'static,
    ) -> Self {
        self.trim = Some(Box::new(trim));
        self
    }

    /// The thread id this memory reads from and writes to.
    pub fn thread_id(&self) -> &str {
        &self.thread_id
    }

    /// Fetches the thread's messages and passes them through the trim
    /// function, if one is configured.
    ///
    /// # Errors
    /// Propagates any error from the underlying `ChatHistory::messages`.
    pub async fn load(&self) -> Result<Vec<Message>> {
        let messages = self.history.messages(&self.thread_id).await?;
        Ok(self.apply_trim(messages))
    }

    /// Appends one message to the thread as-is.
    ///
    /// The trim function is not applied here.
    ///
    /// # Errors
    /// Propagates any error from the underlying `ChatHistory::append`.
    pub async fn append(&self, message: Message) -> Result<()> {
        self.history.append(&self.thread_id, message).await
    }

    /// Replaces the thread's stored messages with the trimmed `messages`.
    ///
    /// This is not atomic. The thread is first cleared, then each message is
    /// written with its own `append` call. If any call fails, the error is
    /// returned and the thread may be left empty or only partially rewritten.
    ///
    /// NOTE(review): the cost is one `append` per message. For backends whose
    /// `append` reads and rewrites the whole list (as the store-backed history
    /// in this file does), that makes a save O(n^2) in serialized data.
    /// Consider a bulk-replace method on `ChatHistory` if this becomes hot.
    ///
    /// # Errors
    /// Propagates the first error from `ChatHistory::clear` or
    /// `ChatHistory::append`.
    pub async fn save(&self, messages: Vec<Message>) -> Result<()> {
        let trimmed = self.apply_trim(messages);
        self.history.clear(&self.thread_id).await?;
        for message in trimmed {
            self.history.append(&self.thread_id, message).await?;
        }
        Ok(())
    }

    /// Deletes every message stored for this thread.
    ///
    /// # Errors
    /// Propagates any error from the underlying `ChatHistory::clear`.
    pub async fn clear(&self) -> Result<()> {
        self.history.clear(&self.thread_id).await
    }

    /// Runs the configured trim function over `messages`.
    ///
    /// Returns `messages` unchanged when no trim function is set.
    fn apply_trim(&self, messages: Vec<Message>) -> Vec<Message> {
        match &self.trim {
            Some(trim) => trim(messages),
            None => messages,
        }
    }
}
#[cfg(test)]
mod test;