use crate::models::Chunk;
use crate::swarm::agent::{ChatMessage, LlmProvider, Role};
use crate::swarm::llm::LlmClient;
use crate::traits::{CerebroError, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// A single knowledge-graph fact: a directed `subject --predicate--> object` statement.
///
/// The field names match the JSON object keys that the extraction prompt in
/// `GraphMemoryLayer::extract_knowledge` asks the LLM to produce, so a response can be
/// deserialized directly into `Vec<EntityTriple>`.
///
/// `Eq` and `Hash` are derived (all fields are plain `String`s) so triples can be
/// deduplicated with a `HashSet` or used as map keys.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)]
pub struct EntityTriple {
    /// Name of the entity the statement is about (source node of the edge).
    pub subject: String,
    /// Relationship label connecting `subject` to `object`.
    pub predicate: String,
    /// Name of the entity the statement points at (target node of the edge).
    pub object: String,
}
/// Storage backend for knowledge-graph triples.
///
/// Implementations must be `Send + Sync`; `GraphMemoryLayer` holds them behind an
/// `Arc<dyn GraphStore>`.
#[async_trait]
pub trait GraphStore: Send + Sync {
/// Stores `triplets` in the backend, associated with the document `doc_id`.
///
/// How (or whether) `doc_id` is persisted is up to the implementation.
async fn upsert_triplets(&self, triplets: &[EntityTriple], doc_id: &str) -> Result<()>;
/// Returns the stored triples that match `query`.
///
/// The matching rules are defined by each implementation.
async fn query_triplets(&self, query: &str) -> Result<Vec<EntityTriple>>;
}
pub mod neo4j {
use super::*;
use neo4rs::{query, Graph};
/// [`GraphStore`] implementation backed by a Neo4j database accessed through `neo4rs`.
pub struct Neo4jGraphStore {
// Connection handle created in `new`; wrapped in `Arc` so it can be shared.
graph: Arc<Graph>,
}
impl Neo4jGraphStore {
    /// Connects to the Neo4j instance at `uri` using the given credentials.
    ///
    /// # Errors
    ///
    /// Returns `CerebroError::StorageError` (prefixed with
    /// `"Neo4j connection error: "`) when `neo4rs` fails to create the connection.
    pub async fn new(uri: &str, user: &str, pass: &str) -> Result<Self> {
        let connection = Graph::new(uri, user, pass)
            .await
            .map_err(|err| {
                let message = format!("Neo4j connection error: {}", err);
                CerebroError::StorageError(message)
            })?;

        let graph = Arc::new(connection);
        Ok(Self { graph })
    }
}
#[async_trait]
impl GraphStore for Neo4jGraphStore {
    /// Merges every triple into Neo4j inside a single transaction.
    ///
    /// For each triple the subject and object are `MERGE`d as `:Entity` nodes keyed by
    /// `name`, and a `:RELATION` edge keyed by `type` (the predicate) and `doc_id` is
    /// `MERGE`d between them.
    ///
    /// An empty `triplets` slice is a no-op: no transaction is opened.
    ///
    /// # Errors
    ///
    /// Returns `CerebroError::StorageError` if starting the transaction, running any
    /// statement, or committing fails. A failing statement returns before `commit`
    /// is reached, so the transaction is never committed in that case.
    async fn upsert_triplets(&self, triplets: &[EntityTriple], doc_id: &str) -> Result<()> {
        // The statement text is identical for every triple; only the parameters vary,
        // so keep it as a constant instead of allocating a fresh `String` per iteration.
        const UPSERT_CQL: &str = "MERGE (s:Entity {name: $subject})\n\
            MERGE (o:Entity {name: $object})\n\
            MERGE (s)-[r:RELATION {type: $predicate, doc_id: $doc_id}]->(o)";

        // Nothing to write: avoid a round-trip just to open and commit an empty transaction.
        if triplets.is_empty() {
            return Ok(());
        }

        let mut txn = self
            .graph
            .start_txn()
            .await
            .map_err(|e| CerebroError::StorageError(e.to_string()))?;

        for t in triplets {
            txn.run(
                query(UPSERT_CQL)
                    .param("subject", t.subject.clone())
                    .param("object", t.object.clone())
                    .param("predicate", t.predicate.clone())
                    .param("doc_id", doc_id.to_string()),
            )
            .await
            .map_err(|e| CerebroError::StorageError(e.to_string()))?;
        }

        txn.commit()
            .await
            .map_err(|e| CerebroError::StorageError(e.to_string()))?;
        Ok(())
    }

    /// Not implemented yet: always returns an empty list and never contacts Neo4j.
    ///
    /// NOTE(review): callers using this backend get no query results regardless of
    /// what has been upserted.
    async fn query_triplets(&self, _query_str: &str) -> Result<Vec<EntityTriple>> {
        Ok(vec![])
    }
}
}
pub mod memory {
use super::*;
use petgraph::graph::{DiGraph, NodeIndex};
use std::collections::HashMap;
use tokio::sync::RwLock;
/// [`GraphStore`] implementation that keeps the graph in a `petgraph` `DiGraph`.
pub struct MemoryGraphStore {
// Directed graph: node weights are entity names, edge weights are predicates.
graph: RwLock<DiGraph<String, String>>,
// Entity name -> node index, so each name maps to one node in `graph`.
node_indices: RwLock<HashMap<String, NodeIndex>>,
}
impl MemoryGraphStore {
    /// Creates a store with an empty graph and an empty name index.
    pub fn new() -> Self {
        let graph = RwLock::new(DiGraph::new());
        let node_indices = RwLock::new(HashMap::new());
        Self {
            graph,
            node_indices,
        }
    }
}

/// `Default` is the same as [`MemoryGraphStore::new`]: an empty store.
impl Default for MemoryGraphStore {
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl GraphStore for MemoryGraphStore {
    /// Inserts each triple as a directed edge `subject -> object` weighted by the predicate.
    ///
    /// Nodes are created on first use of an entity name and reused afterwards. The
    /// operation is idempotent: if an edge with the same predicate already connects
    /// the same subject and object, no additional edge is added, so repeating an
    /// upsert does not produce duplicate query results.
    ///
    /// `_doc_id` is currently ignored by this backend.
    async fn upsert_triplets(&self, triplets: &[EntityTriple], _doc_id: &str) -> Result<()> {
        // Both write locks are held for the whole batch so the graph and the name
        // index are updated together.
        let mut graph = self.graph.write().await;
        let mut node_indices = self.node_indices.write().await;

        for t in triplets {
            let subject_idx = *node_indices
                .entry(t.subject.clone())
                .or_insert_with(|| graph.add_node(t.subject.clone()));
            let object_idx = *node_indices
                .entry(t.object.clone())
                .or_insert_with(|| graph.add_node(t.object.clone()));

            // `add_edge` always creates a new (possibly parallel) edge, so check for an
            // existing identical edge first to keep "upsert" semantics.
            let already_present = graph
                .edges_connecting(subject_idx, object_idx)
                .any(|edge| edge.weight() == &t.predicate);
            if !already_present {
                graph.add_edge(subject_idx, object_idx, t.predicate.clone());
            }
        }
        Ok(())
    }

    /// Returns every stored triple whose subject, object, or predicate contains
    /// `query_str`, compared case-insensitively via `to_lowercase`.
    ///
    /// An empty `query_str` matches every triple, because every string contains the
    /// empty string. Results follow the graph's edge-index order.
    async fn query_triplets(&self, query_str: &str) -> Result<Vec<EntityTriple>> {
        let graph = self.graph.read().await;
        let needle = query_str.to_lowercase();
        let matches = |text: &str| text.to_lowercase().contains(&needle);

        let results = graph
            .edge_indices()
            .filter_map(|edge| {
                let (src, dst) = graph.edge_endpoints(edge)?;
                let subject = graph.node_weight(src)?;
                let object = graph.node_weight(dst)?;
                let predicate = graph.edge_weight(edge)?;

                let hit = matches(subject.as_str())
                    || matches(object.as_str())
                    || matches(predicate.as_str());
                hit.then(|| EntityTriple {
                    subject: subject.clone(),
                    predicate: predicate.clone(),
                    object: object.clone(),
                })
            })
            .collect();
        Ok(results)
    }
}
}
/// Combines a [`GraphStore`] with an LLM client to extract triples from text chunks
/// and to store and query them.
pub struct GraphMemoryLayer {
// Backend that persists and queries triples.
store: Arc<dyn GraphStore>,
// Client used to send the extraction prompt.
llm_client: LlmClient,
// LLM provider to use; when `None`, `extract_knowledge` returns no triples.
provider: Option<LlmProvider>,
}
impl GraphMemoryLayer {
    /// Creates a layer over `store`.
    ///
    /// `provider` selects the LLM used by [`extract_knowledge`](Self::extract_knowledge);
    /// pass `None` to disable extraction (it then always yields an empty list).
    pub fn new(store: Arc<dyn GraphStore>, provider: Option<LlmProvider>) -> Self {
        Self {
            store,
            llm_client: LlmClient::new(),
            provider,
        }
    }

    /// Asks the configured LLM to extract `(subject, predicate, object)` facts from
    /// `chunk.text`.
    ///
    /// Extraction is best-effort:
    /// - with no provider configured, returns `Ok(vec![])` without calling the LLM;
    /// - if the reply cannot be parsed as a JSON array of triples, returns `Ok(vec![])`
    ///   rather than an error.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the LLM client's `chat` call.
    pub async fn extract_knowledge(&self, chunk: &Chunk) -> Result<Vec<EntityTriple>> {
        let provider = match self.provider.as_ref() {
            Some(p) => p,
            None => return Ok(vec![]),
        };

        let system_prompt = "You are a Knowledge Graph extractor. Given a text chunk, extract key facts as a JSON array of objects, each containing 'subject', 'predicate', and 'object'. Only output the raw JSON array. Example: [{\"subject\": \"Cerebro\", \"predicate\": \"uses\", \"object\": \"Neo4j\"}]";
        let messages = vec![
            ChatMessage::new(Role::System, system_prompt),
            ChatMessage::new(Role::User, chunk.text.clone()),
        ];

        let response = self.llm_client.chat(provider, &messages).await?;
        let json_text = Self::json_array_slice(&response.content);

        // Deliberate best-effort: malformed model output yields no triples instead of
        // failing the caller.
        let triplets: Vec<EntityTriple> = serde_json::from_str(json_text).unwrap_or_default();
        Ok(triplets)
    }

    /// Isolates the JSON array inside a raw LLM reply.
    ///
    /// First strips surrounding whitespace and Markdown code fences (```` ```json ````
    /// or ```` ``` ````). Then, if the remaining text contains a `[` followed later by
    /// a `]`, narrows to the span from the first `[` to the last `]`, so prose before
    /// or after the array does not break parsing. Text that is already a bare array
    /// is returned unchanged.
    fn json_array_slice(raw: &str) -> &str {
        let unfenced = raw
            .trim()
            .trim_start_matches("```json")
            .trim_start_matches("```")
            .trim_end_matches("```")
            .trim();

        match (unfenced.find('['), unfenced.rfind(']')) {
            // Both brackets are single-byte ASCII, so these indices are valid
            // char boundaries for slicing.
            (Some(start), Some(end)) if start < end => &unfenced[start..=end],
            _ => unfenced,
        }
    }

    /// Forwards to the underlying store's [`GraphStore::upsert_triplets`].
    pub async fn upsert_triplets(&self, triplets: &[EntityTriple], doc_id: &str) -> Result<()> {
        self.store.upsert_triplets(triplets, doc_id).await
    }

    /// Forwards to the underlying store's [`GraphStore::query_triplets`].
    pub async fn query_graph(&self, query: &str) -> Result<Vec<EntityTriple>> {
        self.store.query_triplets(query).await
    }
}
#[cfg(test)]
mod tests {
    use super::memory::MemoryGraphStore;
    use super::*;

    /// Builds an owned triple from string slices to keep the fixtures short.
    fn triple(subject: &str, predicate: &str, object: &str) -> EntityTriple {
        EntityTriple {
            subject: subject.to_string(),
            predicate: predicate.to_string(),
            object: object.to_string(),
        }
    }

    /// Upserts two chained facts and checks case-insensitive substring queries.
    #[tokio::test]
    async fn test_memory_graph_store() {
        let store = MemoryGraphStore::new();
        let facts = vec![
            triple("Agent", "uses", "Tool"),
            triple("Tool", "modifies", "State"),
        ];
        store.upsert_triplets(&facts, "doc1").await.unwrap();

        // "tool" appears as the object of the first fact and the subject of the second.
        let tool_hits = store.query_triplets("tool").await.unwrap();
        assert_eq!(tool_hits.len(), 2);

        // "agent" only appears as the subject of the first fact.
        let agent_hits = store.query_triplets("agent").await.unwrap();
        assert_eq!(agent_hits.len(), 1);
        assert_eq!(agent_hits[0].subject, "Agent");
    }
}