use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use crate::{
completion::Message,
wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
};
/// Boxed, type-erased error carried by [`MemoryError::Backend`].
///
/// On non-wasm targets the boxed error must also be `Send + Sync`.
#[cfg(not(target_family = "wasm"))]
pub type MemoryBackendError = Box<dyn std::error::Error + Send + Sync + 'static>;
/// Boxed, type-erased error carried by [`MemoryError::Backend`].
///
/// On wasm targets the `Send + Sync` bounds are dropped.
#[cfg(target_family = "wasm")]
pub type MemoryBackendError = Box<dyn std::error::Error + 'static>;
/// Error type returned by the memory traits declared in this module.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum MemoryError {
/// A boxed [`MemoryBackendError`]; [`MemoryError::backend`] builds this variant.
#[error("Memory backend error: {0}")]
Backend(MemoryBackendError),
/// A policy failure described by a free-form message.
// NOTE(review): nothing in this file constructs this variant — confirm its intended producers.
#[error("Memory policy error: {0}")]
Policy(String),
/// An internal failure described by a free-form message.
///
/// `InMemoryConversationMemory` reports a poisoned mutex through this variant.
#[error("Memory internal error: {0}")]
Internal(String),
}
impl MemoryError {
pub fn backend<E>(source: E) -> Self
where
E: Into<MemoryBackendError>,
{
Self::Backend(source.into())
}
}
/// Asynchronous storage of message histories, addressed by a conversation id.
///
/// Every method takes `&self` and returns a [`WasmBoxedFuture`] that borrows both
/// `self` and `conversation_id` for `'a`. Implementors must also satisfy
/// `WasmCompatSend + WasmCompatSync`.
pub trait ConversationMemory: WasmCompatSend + WasmCompatSync {
/// Loads the messages held for `conversation_id`.
///
/// The in-file `InMemoryConversationMemory` yields an empty `Vec` for an unknown id;
/// the trait itself does not state what other implementations do in that case.
fn load<'a>(
&'a self,
conversation_id: &'a str,
) -> WasmBoxedFuture<'a, Result<Vec<Message>, MemoryError>>;
/// Adds `messages` to the history held for `conversation_id`.
///
/// Takes ownership of `messages`.
fn append<'a>(
&'a self,
conversation_id: &'a str,
messages: Vec<Message>,
) -> WasmBoxedFuture<'a, Result<(), MemoryError>>;
/// Discards the history held for `conversation_id`.
fn clear<'a>(
&'a self,
conversation_id: &'a str,
) -> WasmBoxedFuture<'a, Result<(), MemoryError>>;
}
/// Lets an `Arc<M>` be used wherever a [`ConversationMemory`] is expected by
/// forwarding every call to the wrapped `M` (which may be unsized, e.g. a trait object).
impl<M> ConversationMemory for Arc<M>
where
    M: ConversationMemory + ?Sized,
{
    /// Forwards to `M::load`.
    fn load<'a>(
        &'a self,
        conversation_id: &'a str,
    ) -> WasmBoxedFuture<'a, Result<Vec<Message>, MemoryError>> {
        // Reborrow through the `Arc` so the call resolves to the inner implementation.
        let inner: &M = self;
        inner.load(conversation_id)
    }

    /// Forwards to `M::append`.
    fn append<'a>(
        &'a self,
        conversation_id: &'a str,
        messages: Vec<Message>,
    ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
        let inner: &M = self;
        inner.append(conversation_id, messages)
    }

    /// Forwards to `M::clear`.
    fn clear<'a>(
        &'a self,
        conversation_id: &'a str,
    ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
        let inner: &M = self;
        inner.clear(conversation_id)
    }
}
/// Lets a `Box<M>` be used wherever a [`ConversationMemory`] is expected by
/// forwarding every call to the boxed `M` (which may be unsized, e.g. a trait object).
impl<M> ConversationMemory for Box<M>
where
    M: ConversationMemory + ?Sized,
{
    /// Forwards to `M::load`.
    fn load<'a>(
        &'a self,
        conversation_id: &'a str,
    ) -> WasmBoxedFuture<'a, Result<Vec<Message>, MemoryError>> {
        // Path-qualified call: unambiguously targets the boxed value, not the `Box` itself.
        M::load(&**self, conversation_id)
    }

    /// Forwards to `M::append`.
    fn append<'a>(
        &'a self,
        conversation_id: &'a str,
        messages: Vec<Message>,
    ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
        M::append(&**self, conversation_id, messages)
    }

    /// Forwards to `M::clear`.
    fn clear<'a>(
        &'a self,
        conversation_id: &'a str,
    ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
        M::clear(&**self, conversation_id)
    }
}
/// Named bound for callables that take a `Vec<Message>` and return a `Vec<Message>`.
///
/// The trait has no methods of its own; it bundles
/// `Fn(Vec<Message>) -> Vec<Message>` with `WasmCompatSend + WasmCompatSync` so the
/// combination can be written as `dyn MessageFilter`.
pub trait MessageFilter:
Fn(Vec<Message>) -> Vec<Message> + WasmCompatSend + WasmCompatSync
{
}
/// Blanket implementation: every callable meeting the bounds is a [`MessageFilter`],
/// so closures and function items can be passed directly.
impl<F> MessageFilter for F where
F: Fn(Vec<Message>) -> Vec<Message> + WasmCompatSend + WasmCompatSync
{
}
/// Asynchronous callback that receives a batch of messages for a conversation.
///
// NOTE(review): the name suggests this runs when messages are "demoted" out of some
// active window, but no caller is visible in this file — confirm against the call site.
pub trait DemotionHook: WasmCompatSend + WasmCompatSync {
/// Handles `messages` for `conversation_id`.
///
/// Takes ownership of `messages`; the returned future borrows `self` and
/// `conversation_id` for `'a`.
fn on_demote<'a>(
&'a self,
conversation_id: &'a str,
messages: Vec<Message>,
) -> WasmBoxedFuture<'a, Result<(), MemoryError>>;
}
/// A [`DemotionHook`] that ignores its input and always succeeds.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoopDemotionHook;

impl DemotionHook for NoopDemotionHook {
    /// Discards both arguments and resolves to `Ok(())`.
    fn on_demote<'a>(
        &'a self,
        _conversation_id: &'a str,
        _messages: Vec<Message>,
    ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
        // Nothing is captured: the future only carries the success value.
        let done = async move { Ok::<(), MemoryError>(()) };
        Box::pin(done)
    }
}
/// Lets an `Arc<H>` act as a [`DemotionHook`] by forwarding to the wrapped `H`
/// (which may be unsized, e.g. a trait object).
impl<H> DemotionHook for Arc<H>
where
    H: DemotionHook + ?Sized,
{
    /// Forwards to `H::on_demote`.
    fn on_demote<'a>(
        &'a self,
        conversation_id: &'a str,
        messages: Vec<Message>,
    ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
        // Path-qualified call so the inner hook, not the `Arc` wrapper, is invoked.
        H::on_demote(&**self, conversation_id, messages)
    }
}
/// Asynchronously turns a slice of messages into a single [`Compactor::Artifact`].
///
// NOTE(review): presumably used to condense messages that were evicted from a
// conversation's history — no caller is visible in this file; confirm against usage.
pub trait Compactor: WasmCompatSend + WasmCompatSync {
/// Value produced by [`Compactor::compact`].
///
/// It must be convertible into a [`Message`] and be cloneable.
type Artifact: Into<Message> + Clone + WasmCompatSend + WasmCompatSync + 'static;
/// Produces an artifact for `conversation_id` from the `evicted` messages.
///
/// `carry_over` optionally supplies a borrowed, existing artifact.
// NOTE(review): looks like `carry_over` is the artifact from an earlier compaction
// that the new one should build on — confirm with implementors.
fn compact<'a>(
&'a self,
conversation_id: &'a str,
evicted: &'a [Message],
carry_over: Option<&'a Self::Artifact>,
) -> WasmBoxedFuture<'a, Result<Self::Artifact, MemoryError>>;
}
/// Lets an `Arc<C>` act as a [`Compactor`] by forwarding to the wrapped `C`
/// (which may be unsized) and reusing its artifact type.
impl<C> Compactor for Arc<C>
where
    C: Compactor + ?Sized,
{
    /// Same artifact type as the wrapped compactor.
    type Artifact = C::Artifact;

    /// Forwards to `C::compact`.
    fn compact<'a>(
        &'a self,
        conversation_id: &'a str,
        evicted: &'a [Message],
        carry_over: Option<&'a Self::Artifact>,
    ) -> WasmBoxedFuture<'a, Result<Self::Artifact, MemoryError>> {
        // Reborrow through the `Arc` so the call resolves to the inner implementation.
        let inner: &C = self;
        inner.compact(conversation_id, evicted, carry_over)
    }
}
/// Conversation store that keeps every history in a process-local `HashMap`.
///
/// The map sits behind `Arc<Mutex<..>>`, so a clone of this value shares the same
/// underlying storage as the original. The default value has an empty map and no filter.
#[derive(Clone, Default)]
pub struct InMemoryConversationMemory {
// Conversation id -> messages stored for that conversation.
inner: Arc<Mutex<HashMap<String, Vec<Message>>>>,
// Optional transformation applied to the messages returned by `load`.
filter: Option<Arc<dyn MessageFilter>>,
}
impl InMemoryConversationMemory {
pub fn new() -> Self {
Self::default()
}
pub fn with_filter<F>(mut self, filter: F) -> Self
where
F: MessageFilter + 'static,
{
self.filter = Some(Arc::new(filter));
self
}
fn lock(
&self,
) -> Result<std::sync::MutexGuard<'_, HashMap<String, Vec<Message>>>, MemoryError> {
self.inner
.lock()
.map_err(|e| MemoryError::Internal(e.to_string()))
}
}
/// Debug output shows only whether a filter is installed; the stored conversations
/// are not printed.
impl std::fmt::Debug for InMemoryConversationMemory {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The filter is an opaque callable, so print a fixed placeholder when present.
        let filter_marker: Option<&str> = match self.filter {
            Some(_) => Some("<filter>"),
            None => None,
        };
        let mut out = f.debug_struct("InMemoryConversationMemory");
        out.field("filter", &filter_marker);
        out.finish()
    }
}
/// [`ConversationMemory`] implementation that keeps every history in process memory.
///
/// No method awaits while holding the mutex guard.
impl ConversationMemory for InMemoryConversationMemory {
    /// Returns a copy of the messages stored under `conversation_id`.
    ///
    /// An unknown id yields an empty `Vec`. When a filter is installed it is applied
    /// to the copy; the stored history itself is not modified.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError::Internal`] if the internal mutex is poisoned.
    fn load<'a>(
        &'a self,
        conversation_id: &'a str,
    ) -> WasmBoxedFuture<'a, Result<Vec<Message>, MemoryError>> {
        Box::pin(async move {
            // Clone inside a nested scope so the mutex is released before the
            // caller-supplied filter runs.
            let messages = {
                let guard = self.lock()?;
                guard.get(conversation_id).cloned().unwrap_or_default()
            };
            match &self.filter {
                Some(filter) => Ok(filter(messages)),
                None => Ok(messages),
            }
        })
    }

    /// Appends `messages` to the history stored under `conversation_id`, creating
    /// the history if it does not exist yet.
    ///
    /// An empty `messages` vector is a no-op: it returns `Ok(())` without taking the
    /// lock and without creating an entry for `conversation_id`.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError::Internal`] if the internal mutex is poisoned (only
    /// reachable when `messages` is non-empty).
    fn append<'a>(
        &'a self,
        conversation_id: &'a str,
        messages: Vec<Message>,
    ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
        Box::pin(async move {
            // Skip empty batches: `entry(..).or_default()` would otherwise allocate a
            // key and leave a permanent empty entry in the map for a no-op call.
            if messages.is_empty() {
                return Ok(());
            }
            let mut guard = self.lock()?;
            guard
                .entry(conversation_id.to_string())
                .or_default()
                .extend(messages);
            Ok(())
        })
    }

    /// Removes the history stored under `conversation_id`.
    ///
    /// Clearing an unknown id succeeds and changes nothing.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError::Internal`] if the internal mutex is poisoned.
    fn clear<'a>(
        &'a self,
        conversation_id: &'a str,
    ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
        Box::pin(async move {
            let mut guard = self.lock()?;
            guard.remove(conversation_id);
            Ok(())
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::completion::Message;

    /// Shorthand for building a user message.
    fn user(text: &str) -> Message {
        Message::user(text)
    }

    /// Shorthand for building an assistant message.
    fn assistant(text: &str) -> Message {
        Message::assistant(text)
    }

    /// A fresh store is empty, and appended messages come back on load.
    #[tokio::test]
    async fn round_trip() {
        let memory = InMemoryConversationMemory::new();

        let before = memory.load("c1").await.unwrap();
        assert!(before.is_empty());

        let batch = vec![user("hello"), assistant("hi")];
        memory.append("c1", batch).await.unwrap();

        let after = memory.load("c1").await.unwrap();
        assert_eq!(after.len(), 2);
    }

    /// Messages appended under one id are not visible under another.
    #[tokio::test]
    async fn isolation_between_conversations() {
        let memory = InMemoryConversationMemory::new();

        memory.append("a", vec![user("hi a")]).await.unwrap();
        memory.append("b", vec![user("hi b")]).await.unwrap();

        let history_a = memory.load("a").await.unwrap();
        assert_eq!(history_a.len(), 1);
        let history_b = memory.load("b").await.unwrap();
        assert_eq!(history_b.len(), 1);
    }

    /// After `clear`, loading the same id yields nothing.
    #[tokio::test]
    async fn clear_removes_history() {
        let memory = InMemoryConversationMemory::new();

        memory.append("c", vec![user("x")]).await.unwrap();
        memory.clear("c").await.unwrap();

        let remaining = memory.load("c").await.unwrap();
        assert!(remaining.is_empty());
    }

    /// An installed filter is applied to what `load` returns.
    #[tokio::test]
    async fn with_filter_transforms_loaded_messages() {
        // Keeps the last two messages, newest first.
        let last_two_reversed =
            |msgs: Vec<Message>| -> Vec<Message> { msgs.into_iter().rev().take(2).collect() };
        let memory = InMemoryConversationMemory::new().with_filter(last_two_reversed);

        let batch = vec![user("1"), assistant("2"), user("3"), assistant("4")];
        memory.append("c", batch).await.unwrap();

        let loaded = memory.load("c").await.unwrap();
        assert_eq!(loaded.len(), 2, "filter should retain only 2 messages");
    }

    /// Calls made through `Arc<dyn ConversationMemory>` reach the shared inner store.
    #[tokio::test]
    async fn arc_conversation_memory_forwards_to_inner() {
        let inner = Arc::new(InMemoryConversationMemory::new());
        let erased: Arc<dyn ConversationMemory> = inner.clone();

        erased.append("c", vec![user("hello")]).await.unwrap();
        let seen_by_inner = inner.load("c").await.unwrap();
        assert_eq!(seen_by_inner.len(), 1);

        erased.clear("c").await.unwrap();
        let after_clear = inner.load("c").await.unwrap();
        assert!(after_clear.is_empty());
    }

    /// Calls made through `Box<dyn ConversationMemory>` reach the boxed store.
    #[tokio::test]
    async fn boxed_conversation_memory_forwards_to_inner() {
        let boxed: Box<dyn ConversationMemory> = Box::new(InMemoryConversationMemory::new());

        boxed.append("c", vec![user("hello")]).await.unwrap();
        let stored = boxed.load("c").await.unwrap();
        assert_eq!(stored.len(), 1);

        boxed.clear("c").await.unwrap();
        let after_clear = boxed.load("c").await.unwrap();
        assert!(after_clear.is_empty());
    }
}