cerebro 1.1.8

A blazing-fast AI memory layer that enables teams of specialized agents to collaborate through a shared cognitive architecture.
Documentation
use crate::models::Chunk;
use crate::swarm::agent::{ChatMessage, LlmProvider, Role};
use crate::swarm::llm::LlmClient;
use crate::traits::{CerebroError, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

/// An extracted relationship triple representing a concrete factual memory.
///
/// A triple reads as "`subject` `predicate` `object`", for example
/// `Cerebro` / `uses` / `Neo4j`. The serialized field names (`subject`,
/// `predicate`, `object`) are the same keys the knowledge-extraction prompt
/// asks the LLM to emit, so a JSON array of such objects deserializes
/// directly into `Vec<EntityTriple>`.
///
/// `Eq` and `Hash` are derived so triples can be de-duplicated with a
/// `HashSet` or used as map keys.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)]
pub struct EntityTriple {
    /// Entity the fact is about (source of the relationship).
    pub subject: String,
    /// Name of the relationship linking `subject` to `object`.
    pub predicate: String,
    /// Entity the relationship points at (target of the relationship).
    pub object: String,
}

/// Storage backend for [`EntityTriple`]s.
///
/// Implementations in this file are `neo4j::Neo4jGraphStore` and
/// `memory::MemoryGraphStore`. The `Send + Sync` bound lets a store be held
/// as `Arc<dyn GraphStore>`, which is how `GraphMemoryLayer` owns it.
#[async_trait]
pub trait GraphStore: Send + Sync {
    /// Persists `triplets` in the backend.
    ///
    /// `doc_id` names the document the triples came from. How (and whether)
    /// it is recorded is backend-specific: the Neo4j store writes it onto the
    /// relationship, the in-memory store ignores it.
    async fn upsert_triplets(&self, triplets: &[EntityTriple], doc_id: &str) -> Result<()>;
    /// Returns the stored triples that match `query`.
    ///
    /// Matching semantics are defined by each backend.
    async fn query_triplets(&self, query: &str) -> Result<Vec<EntityTriple>>;
}

pub mod neo4j {
    use super::*;
    use neo4rs::{query, Graph};

    pub struct Neo4jGraphStore {
        graph: Arc<Graph>,
    }

    impl Neo4jGraphStore {
        pub async fn new(uri: &str, user: &str, pass: &str) -> Result<Self> {
            let graph = Graph::new(uri, user, pass).await.map_err(|e| {
                CerebroError::StorageError(format!("Neo4j connection error: {}", e))
            })?;
            Ok(Self {
                graph: Arc::new(graph),
            })
        }
    }

    #[async_trait]
    impl GraphStore for Neo4jGraphStore {
        async fn upsert_triplets(&self, triplets: &[EntityTriple], doc_id: &str) -> Result<()> {
            let mut txn = self
                .graph
                .start_txn()
                .await
                .map_err(|e| CerebroError::StorageError(e.to_string()))?;

            for t in triplets {
                let cql = "MERGE (s:Entity {name: $subject})
                     MERGE (o:Entity {name: $object})
                     MERGE (s)-[r:RELATION {type: $predicate, doc_id: $doc_id}]->(o)"
                    .to_string();

                txn.run(
                    query(&cql)
                        .param("subject", t.subject.clone())
                        .param("object", t.object.clone())
                        .param("predicate", t.predicate.clone())
                        .param("doc_id", doc_id.to_string()),
                )
                .await
                .map_err(|e| CerebroError::StorageError(e.to_string()))?;
            }

            txn.commit()
                .await
                .map_err(|e| CerebroError::StorageError(e.to_string()))?;
            Ok(())
        }

        async fn query_triplets(&self, _query_str: &str) -> Result<Vec<EntityTriple>> {
            // Placeholder for Cypher query execution mapping back to EntityTriple
            Ok(vec![])
        }
    }
}

/// In-process [`GraphStore`] backed by a `petgraph` directed graph.
pub mod memory {
    use super::*;
    use petgraph::graph::{DiGraph, NodeIndex};
    use std::collections::HashMap;
    use tokio::sync::RwLock;

    /// Graph store that keeps every triple in memory.
    ///
    /// Node weights are entity names and edge weights are predicates.
    /// `node_indices` maps an entity name to its node so each entity is
    /// represented by exactly one node.
    pub struct MemoryGraphStore {
        graph: RwLock<DiGraph<String, String>>,
        node_indices: RwLock<HashMap<String, NodeIndex>>,
    }

    impl Default for MemoryGraphStore {
        fn default() -> Self {
            Self::new()
        }
    }

    impl MemoryGraphStore {
        /// Creates an empty store.
        pub fn new() -> Self {
            Self {
                graph: RwLock::new(DiGraph::new()),
                node_indices: RwLock::new(HashMap::new()),
            }
        }
    }

    /// Case-insensitive substring test; `needle_lower` must already be lowercased.
    fn contains_ignore_case(haystack: &str, needle_lower: &str) -> bool {
        haystack.to_lowercase().contains(needle_lower)
    }

    #[async_trait]
    impl GraphStore for MemoryGraphStore {
        /// Inserts each triple, reusing existing nodes and edges.
        ///
        /// A triple whose subject, predicate and object are already stored is
        /// skipped, so upserting the same data repeatedly does not create
        /// duplicate edges. The document id is not recorded by this backend.
        async fn upsert_triplets(&self, triplets: &[EntityTriple], _doc_id: &str) -> Result<()> {
            // Both locks are taken for the whole batch so the graph and the
            // name index cannot drift apart.
            let mut graph = self.graph.write().await;
            let mut node_indices = self.node_indices.write().await;

            for t in triplets {
                let subject_idx = *node_indices
                    .entry(t.subject.clone())
                    .or_insert_with(|| graph.add_node(t.subject.clone()));

                let object_idx = *node_indices
                    .entry(t.object.clone())
                    .or_insert_with(|| graph.add_node(t.object.clone()));

                // `add_edge` always creates a new (parallel) edge, so check
                // for an identical relationship first to keep this an upsert.
                let already_present = graph
                    .edges_connecting(subject_idx, object_idx)
                    .any(|edge| edge.weight() == &t.predicate);

                if !already_present {
                    graph.add_edge(subject_idx, object_idx, t.predicate.clone());
                }
            }

            Ok(())
        }

        /// Returns every triple whose subject, object or predicate contains
        /// `query_str`, compared case-insensitively.
        ///
        /// An empty `query_str` matches every stored triple.
        async fn query_triplets(&self, query_str: &str) -> Result<Vec<EntityTriple>> {
            let graph = self.graph.read().await;
            // Lowercase the query once instead of once per edge.
            let needle = query_str.to_lowercase();

            let results = graph
                .edge_indices()
                .filter_map(|edge| {
                    let (src, dst) = graph.edge_endpoints(edge)?;
                    let subject = graph.node_weight(src)?;
                    let object = graph.node_weight(dst)?;
                    let predicate = graph.edge_weight(edge)?;

                    let is_match = contains_ignore_case(subject, &needle)
                        || contains_ignore_case(object, &needle)
                        || contains_ignore_case(predicate, &needle);

                    if is_match {
                        Some(EntityTriple {
                            subject: subject.clone(),
                            predicate: predicate.clone(),
                            object: object.clone(),
                        })
                    } else {
                        None
                    }
                })
                .collect();

            Ok(results)
        }
    }
}

/// Combines a [`GraphStore`] with an LLM client to extract triples from text
/// chunks and to store and query them.
pub struct GraphMemoryLayer {
    /// Backend that persists and queries triples.
    store: Arc<dyn GraphStore>,
    /// Client used to send the extraction prompt to the LLM.
    llm_client: LlmClient,
    /// Provider passed to the LLM client; when `None`, knowledge extraction
    /// returns no triples.
    provider: Option<LlmProvider>,
}

impl GraphMemoryLayer {
    /// Creates a layer on top of `store`.
    ///
    /// `provider` selects the LLM used by [`Self::extract_knowledge`]; pass
    /// `None` to disable extraction (it then yields no triples).
    pub fn new(store: Arc<dyn GraphStore>, provider: Option<LlmProvider>) -> Self {
        Self {
            store,
            llm_client: LlmClient::new(),
            provider,
        }
    }

    /// Returns the span of `raw` from its first `[` to its last `]`.
    ///
    /// This isolates a JSON array from whatever surrounds it in an LLM reply
    /// (Markdown code fences, a preamble, trailing remarks). Returns `None`
    /// when `raw` contains no such span.
    fn json_array_slice(raw: &str) -> Option<&str> {
        let start = raw.find('[')?;
        let end = raw.rfind(']')?;
        // `get` yields `None` instead of panicking if `]` precedes `[`.
        raw.get(start..=end)
    }

    /// Asks the configured LLM to extract fact triples from `chunk`.
    ///
    /// Extraction is best-effort: `Ok` with an empty vector is returned when
    /// no provider is configured or when the reply does not contain a JSON
    /// array that deserializes into `Vec<EntityTriple>`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the LLM client call.
    pub async fn extract_knowledge(&self, chunk: &Chunk) -> Result<Vec<EntityTriple>> {
        let provider = match &self.provider {
            Some(p) => p,
            None => return Ok(vec![]),
        };

        let system_prompt = "You are a Knowledge Graph extractor. Given a text chunk, extract key facts as a JSON array of objects, each containing 'subject', 'predicate', and 'object'. Only output the raw JSON array. Example: [{\"subject\": \"Cerebro\", \"predicate\": \"uses\", \"object\": \"Neo4j\"}]";

        let messages = vec![
            ChatMessage::new(Role::System, system_prompt),
            ChatMessage::new(Role::User, chunk.text.clone()),
        ];

        let response = self.llm_client.chat(provider, &messages).await?;

        // Models often wrap the array in code fences or add commentary even
        // when told not to, so parse only the bracketed span. Unparseable
        // output deliberately degrades to "no facts" rather than an error.
        let raw = response.content.trim();
        let triplets = Self::json_array_slice(raw)
            .and_then(|json| serde_json::from_str::<Vec<EntityTriple>>(json).ok())
            .unwrap_or_default();

        Ok(triplets)
    }

    /// Stores `triplets` for document `doc_id` in the underlying store.
    ///
    /// # Errors
    ///
    /// Returns whatever error the store reports.
    pub async fn upsert_triplets(&self, triplets: &[EntityTriple], doc_id: &str) -> Result<()> {
        self.store.upsert_triplets(triplets, doc_id).await
    }

    /// Queries the underlying store for triples matching `query`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the store reports.
    pub async fn query_graph(&self, query: &str) -> Result<Vec<EntityTriple>> {
        self.store.query_triplets(query).await
    }
}

#[cfg(test)]
mod tests {
    use super::memory::MemoryGraphStore;
    use super::*;

    /// Builds an owned triple from string slices to keep fixtures short.
    fn triple(subject: &str, predicate: &str, object: &str) -> EntityTriple {
        EntityTriple {
            subject: subject.to_owned(),
            predicate: predicate.to_owned(),
            object: object.to_owned(),
        }
    }

    /// Stored triples are found by case-insensitive substring match on any
    /// of their three parts.
    #[tokio::test]
    async fn test_memory_graph_store() {
        let store = MemoryGraphStore::new();
        let fixtures = [
            triple("Agent", "uses", "Tool"),
            triple("Tool", "modifies", "State"),
        ];

        store
            .upsert_triplets(&fixtures, "doc1")
            .await
            .expect("in-memory upsert should not fail");

        // "tool" appears as the object of the first triple and the subject
        // of the second, so both come back.
        let tool_hits = store
            .query_triplets("tool")
            .await
            .expect("in-memory query should not fail");
        assert_eq!(tool_hits.len(), 2);

        // "agent" only appears in the first triple.
        let agent_hits = store
            .query_triplets("agent")
            .await
            .expect("in-memory query should not fail");
        assert_eq!(agent_hits.len(), 1);
        assert_eq!(agent_hits[0].subject, "Agent");
    }
}