use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use crate::engine::MemoryEngine;
use crate::models::{Document, Node};
use crate::traits::{KVStore, Result};
use super::agent::ChatMessage;
/// Shared memory hub that bundles three stores behind one handle.
///
/// * `semantic` — the [`MemoryEngine`] that documents are ingested into and
///   queried from.
/// * `working`  — a [`KVStore`] used for per-agent and swarm-wide key/value state.
/// * `episodic` — an in-process map from agent id to that agent's chat messages.
pub struct CerebroMemoryBus {
    /// Engine used for document ingestion and similarity queries.
    pub semantic: Arc<MemoryEngine>,

    /// Key/value store backing agent-scoped and global state.
    pub working: Arc<dyn KVStore>,

    /// Chat history per agent id, guarded by a read/write lock.
    episodic: Arc<RwLock<HashMap<String, Vec<ChatMessage>>>>,
}
impl CerebroMemoryBus {
    /// Creates a bus over the given semantic engine and working-memory store.
    ///
    /// The episodic history starts out empty.
    pub fn new(semantic: Arc<MemoryEngine>, working: Arc<dyn KVStore>) -> Self {
        Self {
            semantic,
            working,
            episodic: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Stores `output` in semantic memory and returns the id of the new document.
    ///
    /// The document is tagged with `agent_id`, `run_id`, a fixed `source` of
    /// `"swarm_agent"`, and the current UTC time formatted as RFC 3339.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the engine's `ingest_document`.
    pub async fn commit_to_semantic(
        &self,
        agent_id: &str,
        run_id: &str,
        output: &str,
    ) -> Result<String> {
        let doc = Document::new(output)
            .with_metadata("agent_id", agent_id)
            .with_metadata("run_id", run_id)
            .with_metadata("source", "swarm_agent")
            .with_metadata("timestamp", &chrono::Utc::now().to_rfc3339());
        // Capture the id before the document is moved into the engine.
        let doc_id = doc.id.clone();
        self.semantic.ingest_document(doc).await?;
        Ok(doc_id)
    }

    /// Queries semantic memory for up to `top_k` matches for `query`.
    ///
    /// Any error from the engine is discarded and reported as an empty result,
    /// so callers cannot distinguish "no matches" from "query failed".
    pub async fn recall_semantic(&self, query: &str, top_k: usize) -> Vec<(Node, f32)> {
        self.semantic.query(query, top_k).await.unwrap_or_default()
    }

    /// Renders the results of [`recall_semantic`](Self::recall_semantic) as a
    /// text block suitable for inclusion in a prompt.
    ///
    /// Returns an empty string when there are no results. Otherwise the block
    /// is framed by header/footer lines and lists each hit with its 1-based
    /// rank, its score to three decimals, and the chunk text.
    pub async fn recall_as_context(&self, query: &str, top_k: usize) -> String {
        let results = self.recall_semantic(query, top_k).await;
        if results.is_empty() {
            return String::new();
        }

        let mut context = String::from("=== Relevant Prior Knowledge ===\n");
        for (i, (node, score)) in results.iter().enumerate() {
            context.push_str(&format!(
                "\n[{}] (relevance: {:.3})\n{}\n",
                i + 1,
                score,
                node.chunk.text
            ));
        }
        context.push_str("\n=== End Prior Knowledge ===\n");
        context
    }

    /// Writes `value` under `key` in the working-memory namespace of `agent_id`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying store's `set`.
    pub async fn set_state(&self, agent_id: &str, key: &str, value: &str) -> Result<()> {
        self.working.set(&Self::agent_key(agent_id, key), value).await
    }

    /// Reads the value stored under `key` in the namespace of `agent_id`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying store's `get`.
    pub async fn get_state(&self, agent_id: &str, key: &str) -> Result<Option<String>> {
        self.working.get(&Self::agent_key(agent_id, key)).await
    }

    /// Writes `value` under `key` in the swarm-wide namespace.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying store's `set`.
    pub async fn set_global(&self, key: &str, value: &str) -> Result<()> {
        self.working.set(&Self::global_key(key), value).await
    }

    /// Reads the value stored under `key` in the swarm-wide namespace.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying store's `get`.
    pub async fn get_global(&self, key: &str) -> Result<Option<String>> {
        self.working.get(&Self::global_key(key)).await
    }

    /// Appends `message` to the episodic history of `agent_id`, creating the
    /// history on first use.
    pub fn push_message(&self, agent_id: &str, message: ChatMessage) {
        self.episodic_write()
            .entry(agent_id.to_string())
            .or_default()
            .push(message);
    }

    /// Returns a copy of the full episodic history of `agent_id`, oldest first.
    ///
    /// An agent with no recorded messages yields an empty vector.
    pub fn get_history(&self, agent_id: &str) -> Vec<ChatMessage> {
        let store = self.episodic_read();
        store.get(agent_id).cloned().unwrap_or_default()
    }

    /// Returns a copy of the last `n` messages of `agent_id`, oldest first.
    ///
    /// If fewer than `n` messages exist, all of them are returned.
    pub fn get_recent_history(&self, agent_id: &str, n: usize) -> Vec<ChatMessage> {
        let store = self.episodic_read();
        store
            .get(agent_id)
            // Clone only the requested tail instead of the whole history.
            .map(|history| history[history.len().saturating_sub(n)..].to_vec())
            .unwrap_or_default()
    }

    /// Removes the episodic history of every agent.
    pub fn clear_episodic(&self) {
        self.episodic_write().clear();
    }

    /// Returns the ids of all agents that currently have an episodic history.
    ///
    /// The order is the map's iteration order and is therefore unspecified.
    pub fn active_agents(&self) -> Vec<String> {
        self.episodic_read().keys().cloned().collect()
    }

    /// Builds the working-memory key for an agent-scoped entry.
    ///
    /// NOTE(review): the parts are joined with ':' without escaping, so an
    /// `agent_id` or `key` containing ':' could collide with another pair —
    /// confirm callers never pass such values.
    fn agent_key(agent_id: &str, key: &str) -> String {
        format!("agent:{}:{}", agent_id, key)
    }

    /// Builds the working-memory key for a swarm-wide entry.
    fn global_key(key: &str) -> String {
        format!("swarm:{}", key)
    }

    /// Acquires the episodic map for reading.
    ///
    /// A poisoned lock is recovered instead of being treated as "no data":
    /// the map is only ever mutated by a single insert/push or `clear`, so it
    /// remains structurally valid even if a writer panicked.
    fn episodic_read(
        &self,
    ) -> std::sync::RwLockReadGuard<'_, HashMap<String, Vec<ChatMessage>>> {
        self.episodic
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Acquires the episodic map for writing, recovering from poisoning for
    /// the same reason as [`episodic_read`](Self::episodic_read).
    fn episodic_write(
        &self,
    ) -> std::sync::RwLockWriteGuard<'_, HashMap<String, Vec<ChatMessage>>> {
        self.episodic
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunker::RecursiveCharacterChunker;
    use crate::compute::mock::MockEmbedder;
    use crate::storage::kv::MemoryKVStore;
    use crate::storage::memory::MemoryVectorStore;
    use crate::swarm::agent::Role;

    /// Assembles a bus from the in-memory vector/KV stores and the mock embedder.
    fn build_memory_bus() -> CerebroMemoryBus {
        let text_chunker = Arc::new(RecursiveCharacterChunker::new(512, 50));
        let mock_embedder = Arc::new(MockEmbedder::new(8));
        let vector_store = Arc::new(MemoryVectorStore::new());
        let semantic = Arc::new(MemoryEngine::new(text_chunker, mock_embedder, vector_store));
        let working = Arc::new(MemoryKVStore::new());
        CerebroMemoryBus::new(semantic, working)
    }

    // ---- working memory -------------------------------------------------

    /// A value written for an agent can be read back under the same key.
    #[tokio::test]
    async fn test_working_memory_set_get() {
        let bus = build_memory_bus();

        bus.set_state("agent-1", "current_task", "Review security")
            .await
            .unwrap();

        let stored = bus.get_state("agent-1", "current_task").await.unwrap();
        assert_eq!(stored, Some("Review security".to_string()));
    }

    /// The same key under two different agents holds independent values.
    #[tokio::test]
    async fn test_working_memory_isolation() {
        let bus = build_memory_bus();

        bus.set_state("agent-1", "step", "3").await.unwrap();
        bus.set_state("agent-2", "step", "7").await.unwrap();

        let first = bus.get_state("agent-1", "step").await.unwrap();
        let second = bus.get_state("agent-2", "step").await.unwrap();
        assert_eq!(first, Some("3".to_string()));
        assert_eq!(second, Some("7".to_string()));
    }

    /// A swarm-wide value round-trips through set/get.
    #[tokio::test]
    async fn test_global_state() {
        let bus = build_memory_bus();

        bus.set_global("run_status", "running").await.unwrap();

        let status = bus.get_global("run_status").await.unwrap();
        assert_eq!(status, Some("running".to_string()));
    }

    // ---- semantic memory ------------------------------------------------

    /// Committing text yields a non-empty id and makes a query return hits.
    #[tokio::test]
    async fn test_semantic_commit_and_recall() {
        let bus = build_memory_bus();

        let committed_id = bus
            .commit_to_semantic(
                "agent-1",
                "run-001",
                "Rust ensures memory safety without a garbage collector through its ownership system.",
            )
            .await
            .unwrap();
        assert!(!committed_id.is_empty());

        let hits = bus.recall_semantic("memory safety", 5).await;
        assert!(!hits.is_empty());
    }

    /// With nothing ingested, the rendered context is the empty string.
    #[tokio::test]
    async fn test_recall_as_context_empty() {
        let bus = build_memory_bus();

        let rendered = bus.recall_as_context("anything", 5).await;
        assert!(rendered.is_empty());
    }

    // ---- episodic memory ------------------------------------------------

    /// Messages are kept per agent and returned in insertion order.
    #[test]
    fn test_episodic_push_and_get() {
        let bus = build_memory_bus();

        bus.push_message("agent-1", ChatMessage::new(Role::User, "Hello"));
        bus.push_message("agent-1", ChatMessage::new(Role::Assistant, "Hi there!"));
        bus.push_message("agent-2", ChatMessage::new(Role::User, "Different agent"));

        let first_agent = bus.get_history("agent-1");
        assert_eq!(first_agent.len(), 2);
        assert_eq!(first_agent[0].content, "Hello");
        assert_eq!(first_agent[1].content, "Hi there!");

        let second_agent = bus.get_history("agent-2");
        assert_eq!(second_agent.len(), 1);
    }

    /// Asking for the last 3 of 10 messages returns msg-7 through msg-9.
    #[test]
    fn test_episodic_recent_history() {
        let bus = build_memory_bus();

        for idx in 0..10 {
            let message = ChatMessage::new(Role::User, format!("msg-{}", idx));
            bus.push_message("agent-1", message);
        }

        let tail = bus.get_recent_history("agent-1", 3);
        assert_eq!(tail.len(), 3);
        assert_eq!(tail[0].content, "msg-7");
        assert_eq!(tail[2].content, "msg-9");
    }

    /// Clearing episodic memory drops previously pushed messages.
    #[test]
    fn test_episodic_clear() {
        let bus = build_memory_bus();

        bus.push_message("agent-1", ChatMessage::new(Role::User, "Hello"));
        assert_eq!(bus.get_history("agent-1").len(), 1);

        bus.clear_episodic();
        assert!(bus.get_history("agent-1").is_empty());
    }

    /// Every agent that has pushed a message is listed as active.
    #[test]
    fn test_active_agents() {
        let bus = build_memory_bus();

        bus.push_message("agent-a", ChatMessage::new(Role::User, "hi"));
        bus.push_message("agent-b", ChatMessage::new(Role::User, "hey"));

        let listed = bus.active_agents();
        assert_eq!(listed.len(), 2);
        assert!(listed.iter().any(|id| id == "agent-a"));
        assert!(listed.iter().any(|id| id == "agent-b"));
    }
}