use std::collections::BTreeMap;
use std::sync::Mutex;
use crate::error::MemoryError;
use crate::types::{ConversationId, MessageId};
/// A single message handed to [`MemoryFacade::remember`] for storage.
#[derive(Clone, Debug)]
pub struct MemoryEntry {
    /// Conversation the message belongs to.
    pub conversation_id: ConversationId,
    /// Speaker role, e.g. `"user"`; stored as free-form text and not validated here.
    pub role: String,
    /// Plain-text body; this is the text the in-memory backend searches and summarizes.
    pub content: String,
    /// Structured message parts; the `SemanticMemory` adapter serializes them to a JSON string.
    pub parts: Vec<serde_json::Value>,
    /// Optional extra metadata; neither backend in this module reads it.
    pub metadata: Option<serde_json::Value>,
}
/// One hit returned by [`MemoryFacade::recall`].
#[derive(Clone, Debug)]
pub struct MemoryMatch {
    /// Text of the recalled memory.
    pub content: String,
    /// Relevance score: always `1.0` from the in-memory backend; the semantic backend
    /// forwards the score it computed.
    pub score: f32,
    /// Which retrieval path produced this hit.
    pub source: MemorySource,
}
/// Tags a [`MemoryMatch`] with the retrieval strategy that produced it.
///
/// Within this module, keyword search yields `Keyword` and the `SemanticMemory` adapter yields
/// `Semantic`; `Episodic` and `Graph` are not produced here.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MemorySource {
    /// Produced by the `SemanticMemory` recall adapter.
    Semantic,
    /// Not produced in this module.
    Episodic,
    /// Not produced in this module.
    Graph,
    /// Produced by the in-memory case-insensitive substring search.
    Keyword,
}
/// Input to [`MemoryFacade::compact`].
#[derive(Clone, Debug)]
pub struct CompactionContext {
    /// Conversation to compact.
    pub conversation_id: ConversationId,
    /// Budget hint: ignored by the in-memory backend; the `SemanticMemory` adapter derives its
    /// summarization target from it.
    pub token_budget: usize,
}
/// Outcome reported by [`MemoryFacade::compact`]; exact meaning of each field is
/// backend-specific.
#[derive(Clone, Debug)]
pub struct CompactionResult {
    /// Newline-joined text: the removed entries' contents (in-memory backend) or the stored
    /// summaries (semantic backend).
    pub summary: String,
    /// Number of entries removed (in-memory backend) or the conversation's message count
    /// before summarizing (semantic backend).
    pub messages_compacted: usize,
}
/// Backend-agnostic interface for storing, searching, summarizing and compacting
/// conversational memory.
///
/// Implemented in this module by [`InMemoryFacade`] and by `crate::semantic::SemanticMemory`.
// A public trait with `async fn` triggers the `async_fn_in_trait` lint because callers cannot
// require the returned futures to be `Send`. The lint is silenced here; revisit if generic
// callers ever need to spawn these futures on a multi-threaded runtime.
#[allow(async_fn_in_trait)]
pub trait MemoryFacade: Send + Sync {
    /// Stores `entry` and returns the id the backend assigned to it.
    ///
    /// # Errors
    /// Returns a [`MemoryError`] if the backend cannot store the entry.
    async fn remember(&self, entry: MemoryEntry) -> Result<MessageId, MemoryError>;

    /// Searches stored memory for `query`; `limit` caps how many matches are requested.
    ///
    /// # Errors
    /// Returns a [`MemoryError`] if the backend search fails.
    async fn recall(&self, query: &str, limit: usize) -> Result<Vec<MemoryMatch>, MemoryError>;

    /// Produces a textual summary of conversation `conv_id`.
    ///
    /// # Errors
    /// Returns a [`MemoryError`] if the backend cannot read the conversation.
    async fn summarize(&self, conv_id: ConversationId) -> Result<String, MemoryError>;

    /// Compacts the conversation named in `ctx` and reports what was done.
    ///
    /// # Errors
    /// Returns a [`MemoryError`] if the backend fails while compacting.
    async fn compact(&self, ctx: &CompactionContext) -> Result<CompactionResult, MemoryError>;
}
/// Process-local [`MemoryFacade`] backed by a mutex-guarded ordered map.
///
/// Recall is a case-insensitive substring match; nothing is persisted.
#[derive(Default, Debug)]
pub struct InMemoryFacade {
    /// Stored entries keyed by the id assigned in `remember`; `BTreeMap` keeps them in
    /// ascending id order.
    entries: Mutex<BTreeMap<i64, MemoryEntry>>,
    /// Last id handed out (starts at 0, so the first stored entry gets id 1).
    next_id: Mutex<i64>,
}
impl InMemoryFacade {
    /// Creates an empty store; the first entry remembered will receive id 1.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Number of entries currently stored.
    ///
    /// If the entries lock was poisoned by a panicking holder, the guard is recovered and the
    /// map is still counted. Reporting `0` in that case, as before, made `is_empty()` claim
    /// the store was empty while it still held entries.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .len()
    }

    /// Returns `true` when no entries are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Converts a poisoned-mutex error from [`InMemoryFacade`] into the crate error type.
fn lock_poisoned<E: std::fmt::Display>(err: E) -> MemoryError {
    MemoryError::Other(format!("InMemoryFacade lock poisoned: {err}"))
}

impl MemoryFacade for InMemoryFacade {
    /// Stores `entry` under the next sequential id (1, 2, 3, ...) and returns that id.
    ///
    /// # Errors
    /// Returns [`MemoryError::Other`] if either internal lock is poisoned.
    async fn remember(&self, entry: MemoryEntry) -> Result<MessageId, MemoryError> {
        let mut id_guard = self.next_id.lock().map_err(lock_poisoned)?;
        *id_guard += 1;
        let id = *id_guard;
        // The id lock is still held while inserting, so concurrent calls are serialized.
        self.entries.lock().map_err(lock_poisoned)?.insert(id, entry);
        Ok(MessageId(id))
    }

    /// Case-insensitive substring search over stored `content`.
    ///
    /// Every match has score `1.0` and source [`MemorySource::Keyword`]. `limit` is applied
    /// while scanning in ascending id (insertion) order; the kept matches are then sorted
    /// alphabetically by content.
    ///
    /// # Errors
    /// Returns [`MemoryError::Other`] if the entries lock is poisoned.
    async fn recall(&self, query: &str, limit: usize) -> Result<Vec<MemoryMatch>, MemoryError> {
        let query_lower = query.to_lowercase();
        let entries = self.entries.lock().map_err(lock_poisoned)?;
        let mut matches: Vec<MemoryMatch> = entries
            .values()
            .filter(|e| e.content.to_lowercase().contains(&query_lower))
            .take(limit)
            .map(|e| MemoryMatch {
                content: e.content.clone(),
                score: 1.0,
                source: MemorySource::Keyword,
            })
            .collect();
        // The matches own their data, so release the lock before sorting.
        drop(entries);
        // Matches with equal content are identical (same score and source), so an unstable
        // sort is indistinguishable from a stable one here.
        matches.sort_unstable_by(|a, b| a.content.cmp(&b.content));
        Ok(matches)
    }

    /// Joins, with `\n`, the content of every entry in `conv_id` in ascending id order.
    ///
    /// # Errors
    /// Returns [`MemoryError::Other`] if the entries lock is poisoned.
    async fn summarize(&self, conv_id: ConversationId) -> Result<String, MemoryError> {
        let entries = self.entries.lock().map_err(lock_poisoned)?;
        let texts: Vec<&str> = entries
            .values()
            .filter(|e| e.conversation_id == conv_id)
            .map(|e| e.content.as_str())
            .collect();
        Ok(texts.join("\n"))
    }

    /// Removes every entry of `ctx.conversation_id` and returns their contents joined by `\n`
    /// (ascending id order) together with the number removed. `ctx.token_budget` is ignored.
    ///
    /// # Errors
    /// Returns [`MemoryError::Other`] if the entries lock is poisoned.
    async fn compact(&self, ctx: &CompactionContext) -> Result<CompactionResult, MemoryError> {
        let mut entries = self.entries.lock().map_err(lock_poisoned)?;
        let ids: Vec<i64> = entries
            .iter()
            .filter(|(_, e)| e.conversation_id == ctx.conversation_id)
            .map(|(&id, _)| id)
            .collect();
        // Take ownership of each removed entry's content instead of cloning it and then
        // dropping the original.
        let removed: Vec<String> = ids
            .iter()
            .filter_map(|id| entries.remove(id))
            .map(|e| e.content)
            .collect();
        Ok(CompactionResult {
            summary: removed.join("\n"),
            messages_compacted: ids.len(),
        })
    }
}
impl MemoryFacade for crate::semantic::SemanticMemory {
async fn remember(&self, entry: MemoryEntry) -> Result<MessageId, MemoryError> {
let parts_json = serde_json::to_string(&entry.parts)
.map_err(|e| MemoryError::Other(format!("parts serialization failed: {e}")))?;
let (id_opt, _embedded) = self
.remember_with_parts(
entry.conversation_id,
&entry.role,
&entry.content,
&parts_json,
None,
)
.await?;
id_opt.ok_or_else(|| MemoryError::Other("message rejected by admission control".into()))
}
async fn recall(&self, query: &str, limit: usize) -> Result<Vec<MemoryMatch>, MemoryError> {
let recalled = self.recall(query, limit, None).await?;
Ok(recalled
.into_iter()
.map(|r| MemoryMatch {
content: r.message.content,
score: r.score,
source: MemorySource::Semantic,
})
.collect())
}
async fn summarize(&self, conv_id: ConversationId) -> Result<String, MemoryError> {
let summaries = self.load_summaries(conv_id).await?;
Ok(summaries
.into_iter()
.map(|s| s.content)
.collect::<Vec<_>>()
.join("\n"))
}
async fn compact(&self, ctx: &CompactionContext) -> Result<CompactionResult, MemoryError> {
let before = self.message_count(ctx.conversation_id).await?;
let messages_compacted = usize::try_from(before).unwrap_or(0);
let target_msgs = ctx.token_budget.checked_div(4).unwrap_or(512);
let _ = self.summarize(ctx.conversation_id, target_msgs).await?;
let summary = self
.load_summaries(ctx.conversation_id)
.await?
.into_iter()
.map(|s| s.content)
.collect::<Vec<_>>()
.join("\n");
Ok(CompactionResult {
summary,
messages_compacted,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `"user"` entry in `conversation_id` with the given text and no parts/metadata.
    fn user_entry(conversation_id: ConversationId, content: impl Into<String>) -> MemoryEntry {
        MemoryEntry {
            conversation_id,
            role: "user".into(),
            content: content.into(),
            parts: Vec::new(),
            metadata: None,
        }
    }

    #[tokio::test]
    async fn remember_and_recall() {
        let store = InMemoryFacade::new();
        let assigned = store
            .remember(user_entry(ConversationId(1), "Rust ownership model"))
            .await
            .unwrap();
        assert_eq!(assigned, MessageId(1));

        let hits = store.recall("ownership", 10).await.unwrap();
        assert_eq!(hits.len(), 1);
        let hit = &hits[0];
        assert_eq!(hit.content, "Rust ownership model");
        assert_eq!(hit.source, MemorySource::Keyword);
    }

    #[tokio::test]
    async fn recall_no_match() {
        let store = InMemoryFacade::new();
        store
            .remember(user_entry(ConversationId(1), "Rust ownership model"))
            .await
            .unwrap();

        let hits = store.recall("Python", 10).await.unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn summarize_joins_content() {
        let store = InMemoryFacade::new();
        for word in ["Hello", "World"] {
            store
                .remember(user_entry(ConversationId(1), word))
                .await
                .unwrap();
        }

        let text = store.summarize(ConversationId(1)).await.unwrap();
        assert!(text.contains("Hello"));
        assert!(text.contains("World"));
    }

    #[tokio::test]
    async fn compact_removes_conversation_entries() {
        let store = InMemoryFacade::new();
        store
            .remember(user_entry(ConversationId(1), "entry 1"))
            .await
            .unwrap();
        store
            .remember(user_entry(ConversationId(2), "other conv"))
            .await
            .unwrap();

        let ctx = CompactionContext {
            conversation_id: ConversationId(1),
            token_budget: 100,
        };
        let outcome = store.compact(&ctx).await.unwrap();

        assert_eq!(outcome.messages_compacted, 1);
        assert!(outcome.summary.contains("entry 1"));
        assert_eq!(store.len(), 1);
    }

    #[tokio::test]
    async fn recall_respects_limit() {
        let store = InMemoryFacade::new();
        for n in 0..5 {
            store
                .remember(user_entry(ConversationId(1), format!("memory item {n}")))
                .await
                .unwrap();
        }

        let hits = store.recall("memory", 3).await.unwrap();
        assert_eq!(hits.len(), 3);
    }
}