use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use super::{MemoryProvider, MemoryType, MessageEvent};
use crate::{chat::ChatMessage, error::LLMError};
/// Broadcast-channel capacity used by `SharedMemory::new_reactive` when the
/// caller does not pick one via `SharedMemory::new_reactive_with_capacity`.
const DEFAULT_REACTIVE_CAPACITY: usize = 1000;
/// Shareable handle around a [`MemoryProvider`].
///
/// The wrapped provider lives behind an `Arc<RwLock<_>>`, so every clone of a
/// `SharedMemory` refers to the same underlying provider. A handle may also
/// carry a broadcast sender used to publish [`MessageEvent`]s.
pub struct SharedMemory<T: MemoryProvider> {
    /// The wrapped provider, shared between all clones of this handle.
    inner: Arc<RwLock<T>>,
    /// Broadcast sender for message events; `None` for non-reactive handles.
    event_sender: Option<broadcast::Sender<MessageEvent>>,
    /// Memory type reported by the provider at construction time, cached so
    /// it can be returned without taking the lock.
    memory_type: MemoryType,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add a
// `T: Clone` bound even though cloning only copies the `Arc`, the optional
// sender and the cached memory type, never the provider itself.
impl<T: MemoryProvider> Clone for SharedMemory<T> {
    /// Returns another handle to the same underlying provider.
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
            event_sender: self.event_sender.clone(),
            memory_type: self.memory_type.clone(),
        }
    }
}
impl<T: MemoryProvider> SharedMemory<T> {
    /// Wraps `provider` in a shareable handle without an event channel.
    ///
    /// [`subscribe`](Self::subscribe) returns `None` for handles built this way.
    pub fn new(provider: T) -> Self {
        // Read the memory type before the provider is moved behind the lock.
        let memory_type = provider.memory_type();
        Self {
            inner: Arc::new(RwLock::new(provider)),
            event_sender: None,
            memory_type,
        }
    }

    /// Wraps `provider` in a shareable handle that owns a broadcast channel
    /// with room for `capacity` events.
    ///
    /// A `capacity` of `0` is raised to `1`, because `broadcast::channel`
    /// panics when asked for a zero-sized buffer.
    ///
    /// # Panics
    ///
    /// Still panics if `capacity` exceeds the maximum accepted by
    /// `tokio::sync::broadcast::channel`.
    pub fn new_reactive_with_capacity(provider: T, capacity: usize) -> Self {
        // Read the memory type before the provider is moved behind the lock.
        let memory_type = provider.memory_type();
        // The initial receiver is dropped; receivers are created on demand
        // through `subscribe`.
        let (sender, _) = broadcast::channel(capacity.max(1));
        Self {
            inner: Arc::new(RwLock::new(provider)),
            event_sender: Some(sender),
            memory_type,
        }
    }

    /// Wraps `provider` in a reactive handle using
    /// `DEFAULT_REACTIVE_CAPACITY` as the channel capacity.
    pub fn new_reactive(provider: T) -> Self {
        Self::new_reactive_with_capacity(provider, DEFAULT_REACTIVE_CAPACITY)
    }

    /// Returns a new receiver for message events, or `None` when this handle
    /// was created without an event channel.
    pub fn subscribe(&self) -> Option<broadcast::Receiver<MessageEvent>> {
        self.event_sender.as_ref().map(|sender| sender.subscribe())
    }
}
impl<T: MemoryProvider> SharedMemory<T> {
    /// Runs `f` with shared access to the wrapped provider from synchronous code.
    ///
    /// The lock is first tried without waiting, which needs no runtime
    /// support. Only when the lock is contended does this fall back to
    /// blocking the current worker thread until the read lock is granted.
    ///
    /// # Panics
    ///
    /// The contended fallback uses `tokio::task::block_in_place` and
    /// `tokio::runtime::Handle::current`, so it panics when called outside a
    /// Tokio runtime or on a current-thread runtime.
    fn with_inner_read<R>(&self, f: impl FnOnce(&T) -> R) -> R {
        if let Ok(guard) = self.inner.try_read() {
            return f(&guard);
        }
        tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current().block_on(async {
                let guard = self.inner.read().await;
                f(&guard)
            })
        })
    }

    /// Runs `f` with exclusive access to the wrapped provider from synchronous code.
    ///
    /// Uses the same strategy as `with_inner_read`: an immediate `try_write`,
    /// then a blocking wait for the write lock if it is contended.
    ///
    /// # Panics
    ///
    /// The contended fallback panics when called outside a Tokio runtime or
    /// on a current-thread runtime.
    fn with_inner_write<R>(&self, f: impl FnOnce(&mut T) -> R) -> R {
        if let Ok(mut guard) = self.inner.try_write() {
            return f(&mut guard);
        }
        tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current().block_on(async {
                let mut guard = self.inner.write().await;
                f(&mut guard)
            })
        })
    }
}

#[async_trait]
impl<T: MemoryProvider> MemoryProvider for SharedMemory<T> {
    /// Stores `message` in the wrapped provider under the write lock.
    ///
    /// No event is broadcast by this method.
    async fn remember(&mut self, message: &ChatMessage) -> Result<(), LLMError> {
        let mut guard = self.inner.write().await;
        guard.remember(message).await
    }

    /// Forwards the recall query to the wrapped provider under the read lock.
    async fn recall(
        &self,
        query: &str,
        limit: Option<usize>,
    ) -> Result<Vec<ChatMessage>, LLMError> {
        let guard = self.inner.read().await;
        guard.recall(query, limit).await
    }

    /// Clears the wrapped provider under the write lock.
    async fn clear(&mut self) -> Result<(), LLMError> {
        let mut guard = self.inner.write().await;
        guard.clear().await
    }

    /// Returns the memory type cached at construction; does not take the lock.
    fn memory_type(&self) -> MemoryType {
        self.memory_type.clone()
    }

    /// Returns the wrapped provider's `size()`.
    ///
    /// May block the current thread if the lock is held for writing.
    fn size(&self) -> usize {
        self.with_inner_read(|provider| provider.size())
    }

    /// Returns the wrapped provider's `needs_summary()`.
    ///
    /// May block the current thread if the lock is held for writing.
    fn needs_summary(&self) -> bool {
        self.with_inner_read(|provider| provider.needs_summary())
    }

    /// Calls `mark_for_summary()` on the wrapped provider.
    ///
    /// May block the current thread if the lock is held elsewhere.
    fn mark_for_summary(&mut self) {
        self.with_inner_write(|provider| provider.mark_for_summary());
    }

    /// Calls `replace_with_summary(summary)` on the wrapped provider.
    ///
    /// May block the current thread if the lock is held elsewhere.
    fn replace_with_summary(&mut self, summary: String) {
        self.with_inner_write(|provider| provider.replace_with_summary(summary));
    }

    /// Returns a new event receiver, or `None` for non-reactive handles.
    fn get_event_receiver(&self) -> Option<broadcast::Receiver<MessageEvent>> {
        self.subscribe()
    }

    /// Stores `message` and, on reactive handles, broadcasts it tagged with `role`.
    ///
    /// The message is stored through the wrapped provider's `remember`; the
    /// role is only attached to the broadcast event. The broadcast copy has
    /// every `[role]` marker removed from its content, while the stored
    /// message is passed through unchanged.
    ///
    /// # Errors
    ///
    /// Returns the wrapped provider's error if storing fails; no event is
    /// sent in that case.
    async fn remember_with_role(
        &mut self,
        message: &ChatMessage,
        role: String,
    ) -> Result<(), LLMError> {
        // The write guard stays alive until the function returns, so the
        // event is sent while the lock is still held.
        let mut guard = self.inner.write().await;
        guard.remember(message).await?;

        if let Some(sender) = &self.event_sender {
            let mut msg = message.clone();
            msg.content = msg.content.replace(&format!("[{role}]"), "");
            let event = MessageEvent { role, msg };
            // The send result is deliberately ignored, so a failed broadcast
            // never fails the store.
            let _ = sender.send(event);
        }

        Ok(())
    }
}