Skip to main content

mem7_graph/
neo4j.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use mem7_core::MemoryFilter;
5use mem7_error::{Mem7Error, Result};
6use neo4rs::{Graph, query};
7use tracing::{debug, info};
8
9use crate::GraphStore;
10use crate::types::{Entity, GraphSearchResult, Relation};
11
12/// Neo4j-backed graph store for production use.
13pub struct Neo4jGraphStore {
14    graph: Arc<Graph>,
15}
16
17impl Neo4jGraphStore {
18    /// Connect to a Neo4j instance and ensure schema constraints exist.
19    pub async fn new(
20        url: &str,
21        username: &str,
22        password: &str,
23        database: Option<&str>,
24    ) -> Result<Self> {
25        let mut config = neo4rs::ConfigBuilder::default()
26            .uri(url)
27            .user(username)
28            .password(password);
29
30        if let Some(db) = database {
31            config = config.db(db);
32        }
33
34        let graph = Graph::connect(
35            config
36                .build()
37                .map_err(|e| Mem7Error::Graph(format!("Neo4j config error: {e}")))?,
38        )
39        .await
40        .map_err(|e| Mem7Error::Graph(format!("Neo4j connection error: {e}")))?;
41
42        graph
43            .run(query(
44                "CREATE CONSTRAINT IF NOT EXISTS FOR (e:Entity) REQUIRE e.name IS UNIQUE",
45            ))
46            .await
47            .map_err(|e| Mem7Error::Graph(format!("constraint creation error: {e}")))?;
48
49        info!(url, "Neo4jGraphStore connected");
50        Ok(Self {
51            graph: Arc::new(graph),
52        })
53    }
54}
55
56#[async_trait]
57impl GraphStore for Neo4jGraphStore {
58    async fn add_entities(&self, entities: &[Entity], filter: &MemoryFilter) -> Result<()> {
59        for entity in entities {
60            let q = query(
61                "MERGE (e:Entity {name: $name}) \
62                 ON CREATE SET e.entity_type = $entity_type, \
63                               e.user_id = $user_id, \
64                               e.agent_id = $agent_id, \
65                               e.run_id = $run_id, \
66                               e.created_at = timestamp(), \
67                               e.mentions = 1 \
68                 ON MATCH SET  e.mentions = e.mentions + 1",
69            )
70            .param("name", entity.name.as_str())
71            .param("entity_type", entity.entity_type.as_str())
72            .param("user_id", filter.user_id.as_deref().unwrap_or(""))
73            .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
74            .param("run_id", filter.run_id.as_deref().unwrap_or(""));
75
76            self.graph
77                .run(q)
78                .await
79                .map_err(|e| Mem7Error::Graph(format!("add entity error: {e}")))?;
80
81            if let Some(emb) = &entity.embedding {
82                let emb_q = query(
83                    "MATCH (e:Entity {name: $name}) \
84                     CALL db.create.setNodeVectorProperty(e, 'embedding', $embedding)",
85                )
86                .param("name", entity.name.as_str())
87                .param("embedding", emb.clone());
88
89                self.graph
90                    .run(emb_q)
91                    .await
92                    .map_err(|e| Mem7Error::Graph(format!("set embedding error: {e}")))?;
93            }
94        }
95
96        debug!(count = entities.len(), "neo4j: entities added");
97        Ok(())
98    }
99
100    async fn add_relations(
101        &self,
102        relations: &[Relation],
103        entities: &[Entity],
104        filter: &MemoryFilter,
105    ) -> Result<()> {
106        self.add_entities(entities, filter).await?;
107
108        for rel in relations {
109            let q = query(
110                "MATCH (s:Entity {name: $src}), (d:Entity {name: $dst}) \
111                 MERGE (s)-[r:RELATES {relationship: $rel}]->(d) \
112                 ON CREATE SET r.user_id = $user_id, \
113                               r.agent_id = $agent_id, \
114                               r.run_id = $run_id, \
115                               r.created_at = timestamp(), \
116                               r.mentions = 1, \
117                               r.valid = true \
118                 ON MATCH SET  r.mentions = r.mentions + 1",
119            )
120            .param("src", rel.source.as_str())
121            .param("dst", rel.destination.as_str())
122            .param("rel", rel.relationship.as_str())
123            .param("user_id", filter.user_id.as_deref().unwrap_or(""))
124            .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
125            .param("run_id", filter.run_id.as_deref().unwrap_or(""));
126
127            self.graph
128                .run(q)
129                .await
130                .map_err(|e| Mem7Error::Graph(format!("add relation error: {e}")))?;
131        }
132
133        debug!(count = relations.len(), "neo4j: relations added");
134        Ok(())
135    }
136
137    async fn search(
138        &self,
139        query_str: &str,
140        filter: &MemoryFilter,
141        limit: usize,
142    ) -> Result<Vec<GraphSearchResult>> {
143        let cypher = "\
144            MATCH (s:Entity)-[r:RELATES]->(d:Entity) \
145            WHERE r.valid = true \
146                  AND (toLower(s.name) CONTAINS toLower($query) \
147                       OR toLower(d.name) CONTAINS toLower($query) \
148                       OR toLower(r.relationship) CONTAINS toLower($query)) \
149                  AND ($user_id = '' OR r.user_id = $user_id) \
150                  AND ($agent_id = '' OR r.agent_id = $agent_id) \
151            RETURN s.name AS source, r.relationship AS relationship, d.name AS destination \
152            LIMIT $limit";
153
154        let q = query(cypher)
155            .param("query", query_str)
156            .param("user_id", filter.user_id.as_deref().unwrap_or(""))
157            .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
158            .param("limit", limit as i64);
159
160        let mut result = self
161            .graph
162            .execute(q)
163            .await
164            .map_err(|e| Mem7Error::Graph(format!("search error: {e}")))?;
165
166        let mut results = Vec::new();
167        while let Ok(Some(row)) = result.next().await {
168            let source: String = row.get("source").unwrap_or_default();
169            let relationship: String = row.get("relationship").unwrap_or_default();
170            let destination: String = row.get("destination").unwrap_or_default();
171
172            results.push(GraphSearchResult {
173                source,
174                relationship,
175                destination,
176                score: None,
177                created_at: None,
178                mentions: None,
179                last_accessed_at: None,
180            });
181        }
182
183        debug!(count = results.len(), "neo4j: search results");
184        Ok(results)
185    }
186
187    async fn search_by_embedding(
188        &self,
189        embedding: &[f32],
190        filter: &MemoryFilter,
191        threshold: f32,
192        limit: usize,
193    ) -> Result<Vec<GraphSearchResult>> {
194        let cypher = "\
195            MATCH (n:Entity) \
196            WHERE n.embedding IS NOT NULL \
197                  AND ($user_id = '' OR n.user_id = $user_id) \
198                  AND ($agent_id = '' OR n.agent_id = $agent_id) \
199            WITH n, vector.similarity.cosine(n.embedding, $embedding) AS similarity \
200            WHERE similarity >= $threshold \
201            CALL { \
202                WITH n \
203                MATCH (n)-[r:RELATES]->(m:Entity) WHERE r.valid = true \
204                RETURN n.name AS source, r.relationship AS relationship, m.name AS destination, similarity, \
205                       r.created_at AS rel_created_at, r.mentions AS rel_mentions, r.last_accessed_at AS rel_last_accessed \
206                UNION \
207                WITH n, similarity \
208                MATCH (n)<-[r:RELATES]-(m:Entity) WHERE r.valid = true \
209                RETURN m.name AS source, r.relationship AS relationship, n.name AS destination, similarity, \
210                       r.created_at AS rel_created_at, r.mentions AS rel_mentions, r.last_accessed_at AS rel_last_accessed \
211            } \
212            RETURN DISTINCT source, relationship, destination, similarity, rel_created_at, rel_mentions, rel_last_accessed \
213            ORDER BY similarity DESC \
214            LIMIT $limit";
215
216        let q = query(cypher)
217            .param("embedding", embedding.to_vec())
218            .param("threshold", threshold as f64)
219            .param("user_id", filter.user_id.as_deref().unwrap_or(""))
220            .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
221            .param("limit", limit as i64);
222
223        let mut result = self
224            .graph
225            .execute(q)
226            .await
227            .map_err(|e| Mem7Error::Graph(format!("embedding search error: {e}")))?;
228
229        let mut results = Vec::new();
230        while let Ok(Some(row)) = result.next().await {
231            let source: String = row.get("source").unwrap_or_default();
232            let relationship: String = row.get("relationship").unwrap_or_default();
233            let destination: String = row.get("destination").unwrap_or_default();
234            let similarity: f64 = row.get("similarity").unwrap_or_default();
235            let rel_created_at: Option<i64> = row.get("rel_created_at").ok();
236            let rel_mentions: Option<i64> = row.get("rel_mentions").ok();
237            let rel_last_accessed: Option<String> = row.get("rel_last_accessed").ok();
238
239            results.push(GraphSearchResult {
240                source,
241                relationship,
242                destination,
243                score: Some(similarity as f32),
244                created_at: rel_created_at.map(|ts| format!("{ts}")),
245                mentions: rel_mentions.map(|m| m as u32),
246                last_accessed_at: rel_last_accessed,
247            });
248        }
249
250        debug!(count = results.len(), "neo4j: embedding search results");
251        Ok(results)
252    }
253
254    async fn invalidate_relations(
255        &self,
256        triples: &[(String, String, String)],
257        filter: &MemoryFilter,
258    ) -> Result<()> {
259        for (src, rel, dst) in triples {
260            let q = query(
261                "MATCH (s:Entity {name: $src})-[r:RELATES {relationship: $rel}]->(d:Entity {name: $dst}) \
262                 WHERE ($user_id = '' OR r.user_id = $user_id) \
263                 SET r.valid = false",
264            )
265            .param("src", src.as_str())
266            .param("rel", rel.as_str())
267            .param("dst", dst.as_str())
268            .param("user_id", filter.user_id.as_deref().unwrap_or(""));
269
270            self.graph
271                .run(q)
272                .await
273                .map_err(|e| Mem7Error::Graph(format!("invalidate relation error: {e}")))?;
274        }
275
276        debug!(count = triples.len(), "neo4j: relations invalidated");
277        Ok(())
278    }
279
280    async fn rehearse_relations(
281        &self,
282        triples: &[(String, String, String)],
283        filter: &MemoryFilter,
284        now: &str,
285    ) -> Result<()> {
286        for (src, rel, dst) in triples {
287            let q = query(
288                "MATCH (s:Entity {name: $src})-[r:RELATES {relationship: $rel}]->(d:Entity {name: $dst}) \
289                 WHERE r.valid = true AND ($user_id = '' OR r.user_id = $user_id) \
290                 SET r.mentions = r.mentions + 1, r.last_accessed_at = $now",
291            )
292            .param("src", src.as_str())
293            .param("rel", rel.as_str())
294            .param("dst", dst.as_str())
295            .param("user_id", filter.user_id.as_deref().unwrap_or(""))
296            .param("now", now);
297
298            self.graph
299                .run(q)
300                .await
301                .map_err(|e| Mem7Error::Graph(format!("rehearse relation error: {e}")))?;
302        }
303
304        debug!(count = triples.len(), "neo4j: relations rehearsed");
305        Ok(())
306    }
307
308    async fn delete_all(&self, filter: &MemoryFilter) -> Result<()> {
309        if let Some(uid) = &filter.user_id {
310            let q = query(
311                "MATCH (s:Entity)-[r:RELATES]->(d:Entity) \
312                 WHERE r.user_id = $user_id DELETE r",
313            )
314            .param("user_id", uid.as_str());
315
316            self.graph
317                .run(q)
318                .await
319                .map_err(|e| Mem7Error::Graph(format!("delete relations error: {e}")))?;
320
321            let q = query("MATCH (e:Entity) WHERE e.user_id = $user_id DELETE e")
322                .param("user_id", uid.as_str());
323
324            self.graph
325                .run(q)
326                .await
327                .map_err(|e| Mem7Error::Graph(format!("delete entities error: {e}")))?;
328        }
329
330        Ok(())
331    }
332
333    async fn reset(&self) -> Result<()> {
334        self.graph
335            .run(query("MATCH ()-[r:RELATES]->() DELETE r"))
336            .await
337            .map_err(|e| Mem7Error::Graph(format!("reset relations error: {e}")))?;
338
339        self.graph
340            .run(query("MATCH (e:Entity) DELETE e"))
341            .await
342            .map_err(|e| Mem7Error::Graph(format!("reset entities error: {e}")))?;
343
344        info!("neo4j: graph reset");
345        Ok(())
346    }
347}