use neo4rs::{query, Graph};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::models::{Chunk};
use crate::traits::{CerebroError, Result};
/// A single subject–predicate–object fact extracted from a document chunk.
///
/// Value equality and hashing are derived so callers can compare triples and
/// deduplicate them (e.g. through a `HashSet`) before writing them to the graph.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)]
pub struct EntityTriple {
    /// Name of the entity the fact is about (the source node of the edge).
    pub subject: String,
    /// Relationship label connecting `subject` to `object`.
    pub predicate: String,
    /// Name of the entity the fact points to (the target node of the edge).
    pub object: String,
}
/// Graph-backed memory layer that persists [`EntityTriple`]s in Neo4j.
///
/// The driver handle lives behind an [`Arc`], so cloning a `GraphMemoryLayer`
/// only bumps a reference count and every clone shares the same `Graph`.
#[derive(Clone)]
pub struct GraphMemoryLayer {
    graph: Arc<Graph>,
}
impl GraphMemoryLayer {
pub async fn new(uri: &str, user: &str, pass: &str) -> Result<Self> {
let graph = Graph::new(uri, user, pass)
.await
.map_err(|e| CerebroError::StorageError(format!("Neo4j connection error: {}", e)))?;
Ok(Self { graph: Arc::new(graph) })
}
pub async fn extract_knowledge(&self, chunk: &Chunk) -> Result<Vec<EntityTriple>> {
println!("Extracting knowledge from chunk: {}", chunk.index);
let sample = EntityTriple {
subject: "Cerebro".to_string(),
predicate: "stores".to_string(),
object: "Memory Nodes".to_string(),
};
Ok(vec![sample])
}
pub async fn upsert_triplets(&self, triplets: &[EntityTriple], doc_id: &str) -> Result<()> {
let mut txn = self.graph.start_txn()
.await
.map_err(|e| CerebroError::StorageError(e.to_string()))?;
for t in triplets {
let cql = format!(
"MERGE (s:Entity {{name: $subject}})
MERGE (o:Entity {{name: $object}})
MERGE (s)-[r:RELATION {{type: $predicate, doc_id: $doc_id}}]->(o)"
);
txn.run(query(&cql)
.param("subject", t.subject.clone())
.param("object", t.object.clone())
.param("predicate", t.predicate.clone())
.param("doc_id", doc_id.to_string())
)
.await
.map_err(|e| CerebroError::StorageError(e.to_string()))?;
}
txn.commit().await.map_err(|e| CerebroError::StorageError(e.to_string()))?;
Ok(())
}
}