Skip to main content

mem7_graph/
neo4j.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use mem7_core::MemoryFilter;
5use mem7_error::{Mem7Error, Result};
6use neo4rs::{Graph, query};
7use tracing::{debug, info};
8
9use crate::GraphStore;
10use crate::types::{Entity, GraphSearchResult, Relation};
11
12/// Neo4j-backed graph store for production use.
13pub struct Neo4jGraphStore {
14    graph: Arc<Graph>,
15}
16
17impl Neo4jGraphStore {
18    /// Connect to a Neo4j instance and ensure schema constraints exist.
19    pub async fn new(
20        url: &str,
21        username: &str,
22        password: &str,
23        database: Option<&str>,
24    ) -> Result<Self> {
25        let mut config = neo4rs::ConfigBuilder::default()
26            .uri(url)
27            .user(username)
28            .password(password);
29
30        if let Some(db) = database {
31            config = config.db(db);
32        }
33
34        let graph = Graph::connect(
35            config
36                .build()
37                .map_err(|e| Mem7Error::Graph(format!("Neo4j config error: {e}")))?,
38        )
39        .await
40        .map_err(|e| Mem7Error::Graph(format!("Neo4j connection error: {e}")))?;
41
42        graph
43            .run(query(
44                "CREATE CONSTRAINT IF NOT EXISTS FOR (e:Entity) REQUIRE e.name IS UNIQUE",
45            ))
46            .await
47            .map_err(|e| Mem7Error::Graph(format!("constraint creation error: {e}")))?;
48
49        info!(url, "Neo4jGraphStore connected");
50        Ok(Self {
51            graph: Arc::new(graph),
52        })
53    }
54}
55
56#[async_trait]
57impl GraphStore for Neo4jGraphStore {
58    async fn add_entities(&self, entities: &[Entity], filter: &MemoryFilter) -> Result<()> {
59        for entity in entities {
60            let q = query(
61                "MERGE (e:Entity {name: $name}) \
62                 ON CREATE SET e.entity_type = $entity_type, \
63                               e.user_id = $user_id, \
64                               e.agent_id = $agent_id, \
65                               e.run_id = $run_id, \
66                               e.created_at = timestamp(), \
67                               e.mentions = 1 \
68                 ON MATCH SET  e.mentions = e.mentions + 1",
69            )
70            .param("name", entity.name.as_str())
71            .param("entity_type", entity.entity_type.as_str())
72            .param("user_id", filter.user_id.as_deref().unwrap_or(""))
73            .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
74            .param("run_id", filter.run_id.as_deref().unwrap_or(""));
75
76            self.graph
77                .run(q)
78                .await
79                .map_err(|e| Mem7Error::Graph(format!("add entity error: {e}")))?;
80
81            if let Some(emb) = &entity.embedding {
82                let emb_q = query(
83                    "MATCH (e:Entity {name: $name}) \
84                     CALL db.create.setNodeVectorProperty(e, 'embedding', $embedding)",
85                )
86                .param("name", entity.name.as_str())
87                .param("embedding", emb.clone());
88
89                self.graph
90                    .run(emb_q)
91                    .await
92                    .map_err(|e| Mem7Error::Graph(format!("set embedding error: {e}")))?;
93            }
94        }
95
96        debug!(count = entities.len(), "neo4j: entities added");
97        Ok(())
98    }
99
100    async fn add_relations(
101        &self,
102        relations: &[Relation],
103        entities: &[Entity],
104        filter: &MemoryFilter,
105    ) -> Result<()> {
106        self.add_entities(entities, filter).await?;
107
108        for rel in relations {
109            let q = query(
110                "MATCH (s:Entity {name: $src}), (d:Entity {name: $dst}) \
111                 MERGE (s)-[r:RELATES {relationship: $rel}]->(d) \
112                 ON CREATE SET r.user_id = $user_id, \
113                               r.agent_id = $agent_id, \
114                               r.run_id = $run_id, \
115                               r.created_at = timestamp(), \
116                               r.mentions = 1, \
117                               r.valid = true \
118                 ON MATCH SET  r.mentions = r.mentions + 1",
119            )
120            .param("src", rel.source.as_str())
121            .param("dst", rel.destination.as_str())
122            .param("rel", rel.relationship.as_str())
123            .param("user_id", filter.user_id.as_deref().unwrap_or(""))
124            .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
125            .param("run_id", filter.run_id.as_deref().unwrap_or(""));
126
127            self.graph
128                .run(q)
129                .await
130                .map_err(|e| Mem7Error::Graph(format!("add relation error: {e}")))?;
131        }
132
133        debug!(count = relations.len(), "neo4j: relations added");
134        Ok(())
135    }
136
137    async fn search(
138        &self,
139        query_str: &str,
140        filter: &MemoryFilter,
141        limit: usize,
142    ) -> Result<Vec<GraphSearchResult>> {
143        let cypher = "\
144            MATCH (s:Entity)-[r:RELATES]->(d:Entity) \
145            WHERE r.valid = true \
146                  AND (toLower(s.name) CONTAINS toLower($query) \
147                       OR toLower(d.name) CONTAINS toLower($query) \
148                       OR toLower(r.relationship) CONTAINS toLower($query)) \
149                  AND ($user_id = '' OR r.user_id = $user_id) \
150                  AND ($agent_id = '' OR r.agent_id = $agent_id) \
151                  AND ($run_id = '' OR r.run_id = $run_id) \
152            RETURN s.name AS source, r.relationship AS relationship, d.name AS destination \
153            LIMIT $limit";
154
155        let q = query(cypher)
156            .param("query", query_str)
157            .param("user_id", filter.user_id.as_deref().unwrap_or(""))
158            .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
159            .param("run_id", filter.run_id.as_deref().unwrap_or(""))
160            .param("limit", limit as i64);
161
162        let mut result = self
163            .graph
164            .execute(q)
165            .await
166            .map_err(|e| Mem7Error::Graph(format!("search error: {e}")))?;
167
168        let mut results = Vec::new();
169        while let Ok(Some(row)) = result.next().await {
170            let source: String = row.get("source").unwrap_or_default();
171            let relationship: String = row.get("relationship").unwrap_or_default();
172            let destination: String = row.get("destination").unwrap_or_default();
173
174            results.push(GraphSearchResult {
175                source,
176                relationship,
177                destination,
178                score: None,
179                created_at: None,
180                mentions: None,
181                last_accessed_at: None,
182            });
183        }
184
185        debug!(count = results.len(), "neo4j: search results");
186        Ok(results)
187    }
188
189    async fn search_by_embedding(
190        &self,
191        embedding: &[f32],
192        filter: &MemoryFilter,
193        threshold: f32,
194        limit: usize,
195    ) -> Result<Vec<GraphSearchResult>> {
196        let cypher = "\
197            MATCH (n:Entity) \
198            WHERE n.embedding IS NOT NULL \
199                  AND ($user_id = '' OR n.user_id = $user_id) \
200                  AND ($agent_id = '' OR n.agent_id = $agent_id) \
201                  AND ($run_id = '' OR n.run_id = $run_id) \
202            WITH n, vector.similarity.cosine(n.embedding, $embedding) AS similarity \
203            WHERE similarity >= $threshold \
204            CALL { \
205                WITH n \
206                MATCH (n)-[r:RELATES]->(m:Entity) \
207                WHERE r.valid = true \
208                  AND ($user_id = '' OR r.user_id = $user_id) \
209                  AND ($agent_id = '' OR r.agent_id = $agent_id) \
210                  AND ($run_id = '' OR r.run_id = $run_id) \
211                RETURN n.name AS source, r.relationship AS relationship, m.name AS destination, similarity, \
212                       r.created_at AS rel_created_at, r.mentions AS rel_mentions, r.last_accessed_at AS rel_last_accessed \
213                UNION \
214                WITH n, similarity \
215                MATCH (n)<-[r:RELATES]-(m:Entity) \
216                WHERE r.valid = true \
217                  AND ($user_id = '' OR r.user_id = $user_id) \
218                  AND ($agent_id = '' OR r.agent_id = $agent_id) \
219                  AND ($run_id = '' OR r.run_id = $run_id) \
220                RETURN m.name AS source, r.relationship AS relationship, n.name AS destination, similarity, \
221                       r.created_at AS rel_created_at, r.mentions AS rel_mentions, r.last_accessed_at AS rel_last_accessed \
222            } \
223            RETURN DISTINCT source, relationship, destination, similarity, rel_created_at, rel_mentions, rel_last_accessed \
224            ORDER BY similarity DESC \
225            LIMIT $limit";
226
227        let q = query(cypher)
228            .param("embedding", embedding.to_vec())
229            .param("threshold", threshold as f64)
230            .param("user_id", filter.user_id.as_deref().unwrap_or(""))
231            .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
232            .param("run_id", filter.run_id.as_deref().unwrap_or(""))
233            .param("limit", limit as i64);
234
235        let mut result = self
236            .graph
237            .execute(q)
238            .await
239            .map_err(|e| Mem7Error::Graph(format!("embedding search error: {e}")))?;
240
241        let mut results = Vec::new();
242        while let Ok(Some(row)) = result.next().await {
243            let source: String = row.get("source").unwrap_or_default();
244            let relationship: String = row.get("relationship").unwrap_or_default();
245            let destination: String = row.get("destination").unwrap_or_default();
246            let similarity: f64 = row.get("similarity").unwrap_or_default();
247            let rel_created_at: Option<i64> = row.get("rel_created_at").ok();
248            let rel_mentions: Option<i64> = row.get("rel_mentions").ok();
249            let rel_last_accessed: Option<String> = row.get("rel_last_accessed").ok();
250
251            results.push(GraphSearchResult {
252                source,
253                relationship,
254                destination,
255                score: Some(similarity as f32),
256                created_at: rel_created_at.map(|ts| format!("{ts}")),
257                mentions: rel_mentions.map(|m| m as u32),
258                last_accessed_at: rel_last_accessed,
259            });
260        }
261
262        debug!(count = results.len(), "neo4j: embedding search results");
263        Ok(results)
264    }
265
266    async fn invalidate_relations(
267        &self,
268        triples: &[(String, String, String)],
269        filter: &MemoryFilter,
270    ) -> Result<()> {
271        for (src, rel, dst) in triples {
272            let q = query(
273                "MATCH (s:Entity {name: $src})-[r:RELATES {relationship: $rel}]->(d:Entity {name: $dst}) \
274                 WHERE ($user_id = '' OR r.user_id = $user_id) \
275                   AND ($agent_id = '' OR r.agent_id = $agent_id) \
276                   AND ($run_id = '' OR r.run_id = $run_id) \
277                 SET r.valid = false",
278            )
279            .param("src", src.as_str())
280            .param("rel", rel.as_str())
281            .param("dst", dst.as_str())
282            .param("user_id", filter.user_id.as_deref().unwrap_or(""))
283            .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
284            .param("run_id", filter.run_id.as_deref().unwrap_or(""));
285
286            self.graph
287                .run(q)
288                .await
289                .map_err(|e| Mem7Error::Graph(format!("invalidate relation error: {e}")))?;
290        }
291
292        debug!(count = triples.len(), "neo4j: relations invalidated");
293        Ok(())
294    }
295
296    async fn rehearse_relations(
297        &self,
298        triples: &[(String, String, String)],
299        filter: &MemoryFilter,
300        now: &str,
301    ) -> Result<()> {
302        for (src, rel, dst) in triples {
303            let q = query(
304                "MATCH (s:Entity {name: $src})-[r:RELATES {relationship: $rel}]->(d:Entity {name: $dst}) \
305                 WHERE r.valid = true \
306                   AND ($user_id = '' OR r.user_id = $user_id) \
307                   AND ($agent_id = '' OR r.agent_id = $agent_id) \
308                   AND ($run_id = '' OR r.run_id = $run_id) \
309                 SET r.mentions = r.mentions + 1, r.last_accessed_at = $now",
310            )
311            .param("src", src.as_str())
312            .param("rel", rel.as_str())
313            .param("dst", dst.as_str())
314            .param("user_id", filter.user_id.as_deref().unwrap_or(""))
315            .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
316            .param("run_id", filter.run_id.as_deref().unwrap_or(""))
317            .param("now", now);
318
319            self.graph
320                .run(q)
321                .await
322                .map_err(|e| Mem7Error::Graph(format!("rehearse relation error: {e}")))?;
323        }
324
325        debug!(count = triples.len(), "neo4j: relations rehearsed");
326        Ok(())
327    }
328
329    async fn delete_all(&self, filter: &MemoryFilter) -> Result<()> {
330        let q = query(
331            "MATCH (s:Entity)-[r:RELATES]->(d:Entity) \
332             WHERE ($user_id = '' OR r.user_id = $user_id) \
333               AND ($agent_id = '' OR r.agent_id = $agent_id) \
334               AND ($run_id = '' OR r.run_id = $run_id) \
335             DELETE r",
336        )
337        .param("user_id", filter.user_id.as_deref().unwrap_or(""))
338        .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
339        .param("run_id", filter.run_id.as_deref().unwrap_or(""));
340
341        self.graph
342            .run(q)
343            .await
344            .map_err(|e| Mem7Error::Graph(format!("delete relations error: {e}")))?;
345
346        self.graph
347            .run(query("MATCH (e:Entity) WHERE NOT (e)--() DELETE e"))
348            .await
349            .map_err(|e| Mem7Error::Graph(format!("delete orphan entities error: {e}")))?;
350
351        Ok(())
352    }
353
354    async fn reset(&self) -> Result<()> {
355        self.graph
356            .run(query("MATCH ()-[r:RELATES]->() DELETE r"))
357            .await
358            .map_err(|e| Mem7Error::Graph(format!("reset relations error: {e}")))?;
359
360        self.graph
361            .run(query("MATCH (e:Entity) DELETE e"))
362            .await
363            .map_err(|e| Mem7Error::Graph(format!("reset entities error: {e}")))?;
364
365        info!("neo4j: graph reset");
366        Ok(())
367    }
368}