use super::traits::{Memory, MemoryCategory, MemoryEntry, ProceduralMessage};
use async_trait::async_trait;
use std::sync::Arc;
/// A [`Memory`] wrapper that confines operations to a single namespace.
///
/// Keyed writes (`store`, `store_with_metadata`) are tagged with this handle's
/// namespace. Lookups, listings, deletions and counts only consider entries whose
/// namespace matches it.
pub struct NamespacedMemory {
    /// Backend that actually holds the entries. It is held behind an `Arc`, so it
    /// may be shared with other handles.
    inner: Arc<dyn Memory>,
    /// The namespace every operation on this handle is bound to.
    namespace: String,
}

impl NamespacedMemory {
    /// Builds a handle over `inner` that is bound to `namespace`.
    pub fn new(inner: Arc<dyn Memory>, namespace: String) -> Self {
        NamespacedMemory { namespace, inner }
    }

    /// The namespace this handle is bound to.
    pub fn namespace(&self) -> &str {
        self.namespace.as_str()
    }
}
/// Forwards every [`Memory`] operation to the wrapped backend and scopes it to
/// `self.namespace`.
#[async_trait]
impl Memory for NamespacedMemory {
    /// Reports the wrapped backend's name unchanged.
    fn name(&self) -> &str {
        self.inner.name()
    }

    /// Stores `content` under `key`, tagged with this handle's namespace and with
    /// no explicit importance.
    async fn store(
        &self,
        key: &str,
        content: &str,
        category: MemoryCategory,
        session_id: Option<&str>,
    ) -> anyhow::Result<()> {
        let scope = Some(self.namespace.as_str());
        self.inner
            .store_with_metadata(key, content, category, session_id, scope, None)
            .await
    }

    /// Searches only this handle's namespace by delegating to the backend's
    /// namespaced recall.
    async fn recall(
        &self,
        query: &str,
        limit: usize,
        session_id: Option<&str>,
        since: Option<&str>,
        until: Option<&str>,
    ) -> anyhow::Result<Vec<MemoryEntry>> {
        let scope = self.namespace.as_str();
        self.inner
            .recall_namespaced(scope, query, limit, session_id, since, until)
            .await
    }

    /// Looks `key` up in the backend. Returns `None` if the entry is missing or
    /// belongs to another namespace.
    async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
        match self.inner.get(key).await? {
            Some(found) if found.namespace == self.namespace => Ok(Some(found)),
            _ => Ok(None),
        }
    }

    /// Lists backend entries matching the filters, then drops every entry that
    /// belongs to another namespace.
    async fn list(
        &self,
        category: Option<&MemoryCategory>,
        session_id: Option<&str>,
    ) -> anyhow::Result<Vec<MemoryEntry>> {
        let mut visible = self.inner.list(category, session_id).await?;
        visible.retain(|candidate| candidate.namespace == self.namespace);
        Ok(visible)
    }

    /// Deletes `key` only when the backend's entry for it lives in this
    /// namespace. Otherwise returns `Ok(false)` without deleting anything.
    async fn forget(&self, key: &str) -> anyhow::Result<bool> {
        let in_scope = matches!(
            self.inner.get(key).await?,
            Some(existing) if existing.namespace == self.namespace
        );
        if !in_scope {
            return Ok(false);
        }
        self.inner.forget(key).await
    }

    /// Counts this namespace's entries by listing everything in the backend and
    /// filtering. The cost is linear in the total number of stored entries.
    async fn count(&self) -> anyhow::Result<usize> {
        let everything = self.inner.list(None, None).await?;
        let ours = everything
            .iter()
            .filter(|candidate| candidate.namespace == self.namespace)
            .count();
        Ok(ours)
    }

    /// Reports the wrapped backend's health unchanged.
    async fn health_check(&self) -> bool {
        self.inner.health_check().await
    }

    /// Forwards procedural messages to the wrapped backend as-is.
    ///
    /// NOTE(review): unlike `store`/`store_with_metadata`, no namespace is
    /// attached here. Confirm whether procedural storage should also be
    /// isolated per namespace.
    async fn store_procedural(
        &self,
        messages: &[ProceduralMessage],
        session_id: Option<&str>,
    ) -> anyhow::Result<()> {
        self.inner.store_procedural(messages, session_id).await
    }

    /// Delegates to the backend only when `namespace` is this handle's own.
    /// Any other namespace yields an empty result rather than an error.
    async fn recall_namespaced(
        &self,
        namespace: &str,
        query: &str,
        limit: usize,
        session_id: Option<&str>,
        since: Option<&str>,
        until: Option<&str>,
    ) -> anyhow::Result<Vec<MemoryEntry>> {
        if namespace != self.namespace {
            return Ok(Vec::new());
        }
        let scope = self.namespace.as_str();
        self.inner
            .recall_namespaced(scope, query, limit, session_id, since, until)
            .await
    }

    /// Stores an entry with metadata. The caller-supplied namespace is ignored;
    /// the entry is always written to this handle's namespace.
    async fn store_with_metadata(
        &self,
        key: &str,
        content: &str,
        category: MemoryCategory,
        session_id: Option<&str>,
        _namespace: Option<&str>,
        importance: Option<f64>,
    ) -> anyhow::Result<()> {
        let scope = Some(self.namespace.as_str());
        self.inner
            .store_with_metadata(key, content, category, session_id, scope, importance)
            .await
    }

    /// Purges `namespace` in the backend if it is this handle's own namespace.
    ///
    /// # Errors
    /// Fails without touching the backend when `namespace` differs from this
    /// handle's namespace. Otherwise propagates the backend's result.
    async fn purge_namespace(&self, namespace: &str) -> anyhow::Result<usize> {
        if namespace != self.namespace {
            anyhow::bail!(
                "Cannot purge namespace '{}' from isolation context '{}'",
                namespace,
                self.namespace
            );
        }
        self.inner.purge_namespace(namespace).await
    }

    /// Forgets, one key at a time, every entry of `session_id` that belongs to
    /// this namespace, and returns how many deletions the backend reported.
    /// Stops at the first backend error.
    async fn purge_session(&self, session_id: &str) -> anyhow::Result<usize> {
        let candidates = self.inner.list(None, Some(session_id)).await?;
        let mut removed = 0;
        for candidate in candidates {
            if candidate.namespace != self.namespace {
                continue;
            }
            if self.inner.forget(&candidate.key).await? {
                removed += 1;
            }
        }
        Ok(removed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::NoneMemory;

    /// Builds a handle bound to `test_namespace` over the no-op backend.
    fn scoped() -> NamespacedMemory {
        NamespacedMemory::new(Arc::new(NoneMemory::new()), "test_namespace".to_string())
    }

    #[tokio::test]
    async fn namespaced_memory_enforces_namespace_on_store() {
        // NoneMemory keeps nothing, so this can only show that the rewritten
        // store path succeeds. It cannot observe the namespace the entry got.
        let namespaced = scoped();
        namespaced
            .store("key1", "value1", MemoryCategory::Core, None)
            .await
            .unwrap();
        assert_eq!(namespaced.namespace(), "test_namespace");
    }

    #[tokio::test]
    async fn namespaced_memory_prevents_cross_namespace_access() {
        let namespaced = scoped();
        let results = namespaced
            .recall_namespaced("other_namespace", "query", 10, None, None, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn namespaced_memory_rejects_cross_namespace_purge() {
        // The wrapper itself refuses before reaching the backend, so this holds
        // regardless of the inner implementation.
        let namespaced = scoped();
        let err = namespaced
            .purge_namespace("other_namespace")
            .await
            .expect_err("purging a foreign namespace must be refused");
        assert!(err
            .to_string()
            .contains("Cannot purge namespace 'other_namespace'"));
    }

    #[tokio::test]
    async fn namespaced_memory_purge_session_on_empty_backend_removes_nothing() {
        let namespaced = scoped();
        assert_eq!(namespaced.purge_session("session-1").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn namespaced_memory_get_and_forget_on_empty_backend() {
        let namespaced = scoped();
        assert!(namespaced.get("missing").await.unwrap().is_none());
        assert!(!namespaced.forget("missing").await.unwrap());
    }

    #[tokio::test]
    async fn namespaced_memory_delegates_correctly() {
        let namespaced = scoped();
        assert_eq!(namespaced.name(), "none");
        assert!(namespaced.health_check().await);
        assert_eq!(namespaced.count().await.unwrap(), 0);
    }
}