use crate::memory::MemoryError;
use crate::memory::Summarizer;
use juncture_core::state::messages::Message;
use juncture_core::store::{Store, StoreError};
use serde::Serialize;
use std::fmt;
use std::sync::Arc;
/// Conversation-history helper that decides when a message list has grown past
/// a threshold and persists / loads per-thread summaries in a [`Store`].
pub struct ConversationMemory<S>
where
    S: Store,
{
    /// Backing store that summaries are written to and read from.
    store: Arc<S>,
    /// Optional summarizer; when absent, `summarize_messages` returns `Ok(None)`.
    summarizer: Option<Arc<dyn Summarizer>>,
    /// Store namespace used for every summary key.
    namespace: String,
    /// Summarization is triggered once a message list holds more than this many entries.
    max_messages: usize,
}
/// JSON payload written to the store for one thread's summary.
#[derive(Serialize)]
struct SummaryData<'s> {
    /// Summary text, borrowed from the caller for the duration of serialization.
    summary: &'s str,
    /// Timestamp at which the summary was written (filled by the caller).
    stored_at: String,
}
impl<S: Store> ConversationMemory<S> {
#[must_use]
pub fn new(store: Arc<S>, namespace: String) -> Self {
Self {
store,
summarizer: None,
namespace,
max_messages: 100,
}
}
#[must_use]
pub fn with_summarizer(mut self, summarizer: Arc<dyn Summarizer>) -> Self {
self.summarizer = Some(summarizer);
self
}
#[must_use]
pub fn with_max_messages(mut self, max: usize) -> Self {
assert!(max > 0, "max_messages must be greater than 0");
self.max_messages = max;
self
}
#[must_use]
pub const fn should_summarize(&self, messages: &[Message]) -> bool {
messages.len() > self.max_messages
}
pub async fn summarize_messages(
&self,
thread_id: &str,
messages: &[Message],
) -> Result<Option<String>, MemoryError> {
if !self.should_summarize(messages) {
return Ok(None);
}
let Some(summarizer) = &self.summarizer else {
return Ok(None);
};
let num_to_summarize = messages.len() - self.max_messages / 2;
let text_to_summarize = messages
.iter()
.take(num_to_summarize)
.map(|msg| format!("[{:?}]: {}", msg.role, msg.content_text()))
.collect::<Vec<_>>()
.join("\n\n");
let summary = summarizer.summarize(&text_to_summarize).await?;
self.store_summary(thread_id, &summary).await?;
Ok(Some(summary))
}
pub async fn store_summary(&self, thread_id: &str, summary: &str) -> Result<(), MemoryError> {
let key = format!("summary:{thread_id}");
let data = SummaryData {
summary,
stored_at: chrono::Utc::now().to_rfc3339(),
};
let value = serde_json::to_value(data).map_err(|e| {
MemoryError::StoreError(format!("failed to serialize summary data: {e}"))
})?;
self.store
.put(&self.namespace, &key, value, None)
.await
.map_err(|e| match e {
StoreError::Serialize(e) => {
MemoryError::StoreError(format!("serialization error: {e}"))
}
StoreError::Storage(e) => MemoryError::StoreError(format!("storage error: {e}")),
_ => MemoryError::StoreError(format!("store error: {e}")),
})?;
Ok(())
}
pub async fn get_summary(&self, thread_id: &str) -> Result<Option<String>, MemoryError> {
let key = format!("summary:{thread_id}");
let item = self
.store
.get(&self.namespace, &key)
.await
.map_err(|e| match e {
StoreError::Serialize(e) => {
MemoryError::StoreError(format!("serialization error: {e}"))
}
StoreError::Storage(e) => MemoryError::StoreError(format!("storage error: {e}")),
_ => MemoryError::StoreError(format!("store error: {e}")),
})?;
match item {
Some(item) => {
let summary = item
.value
.get("summary")
.and_then(|v| v.as_str())
.ok_or_else(|| {
MemoryError::StoreError(
"invalid summary data: missing 'summary' field".to_string(),
)
})?;
Ok(Some(summary.to_string()))
}
None => Ok(None),
}
}
}
impl<S> fmt::Debug for ConversationMemory<S>
where
    S: Store,
{
    /// Formats the memory for diagnostics: the store is shown as an opaque
    /// placeholder and the summarizer only as whether one is configured.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring so adding a field forces a decision here.
        let Self {
            store: _,
            summarizer,
            namespace,
            max_messages,
        } = self;
        let has_summarizer = summarizer.is_some();
        let mut out = f.debug_struct("ConversationMemory");
        out.field("store", &"<Store>");
        out.field("summarizer", &has_summarizer);
        out.field("namespace", namespace);
        out.field("max_messages", max_messages);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use juncture_core::store::MemoryStore;

    /// Builds a memory over a fresh in-memory store in the `"test"` namespace.
    fn fresh_memory() -> ConversationMemory<MemoryStore> {
        ConversationMemory::new(Arc::new(MemoryStore::new()), "test".to_string())
    }

    /// Builds `count` identical human messages.
    fn hellos(count: usize) -> Vec<Message> {
        vec![Message::human("Hello"); count]
    }

    #[test]
    fn test_conversation_memory_new() {
        let memory = fresh_memory();
        assert_eq!(memory.namespace, "test");
        assert_eq!(memory.max_messages, 100);
        assert!(memory.summarizer.is_none());
    }

    #[test]
    fn test_conversation_memory_builder() {
        let memory = fresh_memory().with_max_messages(50);
        assert_eq!(memory.max_messages, 50);
    }

    #[test]
    #[should_panic(expected = "max_messages must be greater than 0")]
    fn test_conversation_memory_invalid_max() {
        let _ = fresh_memory().with_max_messages(0);
    }

    #[test]
    fn test_should_summarize_below_threshold() {
        let memory = fresh_memory().with_max_messages(100);
        assert!(!memory.should_summarize(&hellos(50)));
    }

    #[test]
    fn test_should_summarize_above_threshold() {
        let memory = fresh_memory().with_max_messages(100);
        assert!(memory.should_summarize(&hellos(101)));
    }

    #[tokio::test]
    async fn test_store_and_get_summary() {
        let memory = fresh_memory();
        memory
            .store_summary("thread_1", "This is a summary")
            .await
            .expect("store_summary failed");
        let retrieved = memory
            .get_summary("thread_1")
            .await
            .expect("get_summary failed");
        assert_eq!(retrieved.as_deref(), Some("This is a summary"));
    }

    #[tokio::test]
    async fn test_get_nonexistent_summary() {
        let memory = fresh_memory();
        let result = memory
            .get_summary("nonexistent")
            .await
            .expect("get_summary failed");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_summarize_below_threshold() {
        let memory = fresh_memory().with_max_messages(100);
        let messages = hellos(50);
        let result = memory
            .summarize_messages("thread_1", &messages)
            .await
            .expect("summarize_messages failed");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_summarize_no_summarizer_configured() {
        let memory = fresh_memory().with_max_messages(100);
        let messages = hellos(101);
        let result = memory
            .summarize_messages("thread_1", &messages)
            .await
            .expect("summarize_messages failed");
        assert!(result.is_none());
    }
}