1use std::sync::Arc;
2
3use async_trait::async_trait;
4use mem7_core::MemoryFilter;
5use mem7_error::{Mem7Error, Result};
6use neo4rs::{Graph, query};
7use tracing::{debug, info};
8
9use crate::GraphStore;
10use crate::types::{Entity, GraphSearchResult, Relation};
11
/// [`GraphStore`] implementation backed by a Neo4j database.
///
/// Entities are stored as `:Entity` nodes (unique by `name`) and relations as `:RELATES`
/// relationships between them.
pub struct Neo4jGraphStore {
    /// Shared handle to the Neo4j connection, created in `Neo4jGraphStore::new`.
    graph: Arc<Graph>,
}
16
17impl Neo4jGraphStore {
18 pub async fn new(
20 url: &str,
21 username: &str,
22 password: &str,
23 database: Option<&str>,
24 ) -> Result<Self> {
25 let mut config = neo4rs::ConfigBuilder::default()
26 .uri(url)
27 .user(username)
28 .password(password);
29
30 if let Some(db) = database {
31 config = config.db(db);
32 }
33
34 let graph = Graph::connect(
35 config
36 .build()
37 .map_err(|e| Mem7Error::Graph(format!("Neo4j config error: {e}")))?,
38 )
39 .await
40 .map_err(|e| Mem7Error::Graph(format!("Neo4j connection error: {e}")))?;
41
42 graph
43 .run(query(
44 "CREATE CONSTRAINT IF NOT EXISTS FOR (e:Entity) REQUIRE e.name IS UNIQUE",
45 ))
46 .await
47 .map_err(|e| Mem7Error::Graph(format!("constraint creation error: {e}")))?;
48
49 info!(url, "Neo4jGraphStore connected");
50 Ok(Self {
51 graph: Arc::new(graph),
52 })
53 }
54}
55
56#[async_trait]
57impl GraphStore for Neo4jGraphStore {
58 async fn add_entities(&self, entities: &[Entity], filter: &MemoryFilter) -> Result<()> {
59 for entity in entities {
60 let q = query(
61 "MERGE (e:Entity {name: $name}) \
62 ON CREATE SET e.entity_type = $entity_type, \
63 e.user_id = $user_id, \
64 e.agent_id = $agent_id, \
65 e.run_id = $run_id, \
66 e.created_at = timestamp(), \
67 e.mentions = 1 \
68 ON MATCH SET e.mentions = e.mentions + 1",
69 )
70 .param("name", entity.name.as_str())
71 .param("entity_type", entity.entity_type.as_str())
72 .param("user_id", filter.user_id.as_deref().unwrap_or(""))
73 .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
74 .param("run_id", filter.run_id.as_deref().unwrap_or(""));
75
76 self.graph
77 .run(q)
78 .await
79 .map_err(|e| Mem7Error::Graph(format!("add entity error: {e}")))?;
80
81 if let Some(emb) = &entity.embedding {
82 let emb_q = query(
83 "MATCH (e:Entity {name: $name}) \
84 CALL db.create.setNodeVectorProperty(e, 'embedding', $embedding)",
85 )
86 .param("name", entity.name.as_str())
87 .param("embedding", emb.clone());
88
89 self.graph
90 .run(emb_q)
91 .await
92 .map_err(|e| Mem7Error::Graph(format!("set embedding error: {e}")))?;
93 }
94 }
95
96 debug!(count = entities.len(), "neo4j: entities added");
97 Ok(())
98 }
99
100 async fn add_relations(
101 &self,
102 relations: &[Relation],
103 entities: &[Entity],
104 filter: &MemoryFilter,
105 ) -> Result<()> {
106 self.add_entities(entities, filter).await?;
107
108 for rel in relations {
109 let q = query(
110 "MATCH (s:Entity {name: $src}), (d:Entity {name: $dst}) \
111 MERGE (s)-[r:RELATES {relationship: $rel}]->(d) \
112 ON CREATE SET r.user_id = $user_id, \
113 r.agent_id = $agent_id, \
114 r.run_id = $run_id, \
115 r.created_at = timestamp(), \
116 r.mentions = 1, \
117 r.valid = true \
118 ON MATCH SET r.mentions = r.mentions + 1",
119 )
120 .param("src", rel.source.as_str())
121 .param("dst", rel.destination.as_str())
122 .param("rel", rel.relationship.as_str())
123 .param("user_id", filter.user_id.as_deref().unwrap_or(""))
124 .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
125 .param("run_id", filter.run_id.as_deref().unwrap_or(""));
126
127 self.graph
128 .run(q)
129 .await
130 .map_err(|e| Mem7Error::Graph(format!("add relation error: {e}")))?;
131 }
132
133 debug!(count = relations.len(), "neo4j: relations added");
134 Ok(())
135 }
136
137 async fn search(
138 &self,
139 query_str: &str,
140 filter: &MemoryFilter,
141 limit: usize,
142 ) -> Result<Vec<GraphSearchResult>> {
143 let cypher = "\
144 MATCH (s:Entity)-[r:RELATES]->(d:Entity) \
145 WHERE r.valid = true \
146 AND (toLower(s.name) CONTAINS toLower($query) \
147 OR toLower(d.name) CONTAINS toLower($query) \
148 OR toLower(r.relationship) CONTAINS toLower($query)) \
149 AND ($user_id = '' OR r.user_id = $user_id) \
150 AND ($agent_id = '' OR r.agent_id = $agent_id) \
151 RETURN s.name AS source, r.relationship AS relationship, d.name AS destination \
152 LIMIT $limit";
153
154 let q = query(cypher)
155 .param("query", query_str)
156 .param("user_id", filter.user_id.as_deref().unwrap_or(""))
157 .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
158 .param("limit", limit as i64);
159
160 let mut result = self
161 .graph
162 .execute(q)
163 .await
164 .map_err(|e| Mem7Error::Graph(format!("search error: {e}")))?;
165
166 let mut results = Vec::new();
167 while let Ok(Some(row)) = result.next().await {
168 let source: String = row.get("source").unwrap_or_default();
169 let relationship: String = row.get("relationship").unwrap_or_default();
170 let destination: String = row.get("destination").unwrap_or_default();
171
172 results.push(GraphSearchResult {
173 source,
174 relationship,
175 destination,
176 score: None,
177 created_at: None,
178 mentions: None,
179 last_accessed_at: None,
180 });
181 }
182
183 debug!(count = results.len(), "neo4j: search results");
184 Ok(results)
185 }
186
187 async fn search_by_embedding(
188 &self,
189 embedding: &[f32],
190 filter: &MemoryFilter,
191 threshold: f32,
192 limit: usize,
193 ) -> Result<Vec<GraphSearchResult>> {
194 let cypher = "\
195 MATCH (n:Entity) \
196 WHERE n.embedding IS NOT NULL \
197 AND ($user_id = '' OR n.user_id = $user_id) \
198 AND ($agent_id = '' OR n.agent_id = $agent_id) \
199 WITH n, vector.similarity.cosine(n.embedding, $embedding) AS similarity \
200 WHERE similarity >= $threshold \
201 CALL { \
202 WITH n \
203 MATCH (n)-[r:RELATES]->(m:Entity) WHERE r.valid = true \
204 RETURN n.name AS source, r.relationship AS relationship, m.name AS destination, similarity, \
205 r.created_at AS rel_created_at, r.mentions AS rel_mentions, r.last_accessed_at AS rel_last_accessed \
206 UNION \
207 WITH n, similarity \
208 MATCH (n)<-[r:RELATES]-(m:Entity) WHERE r.valid = true \
209 RETURN m.name AS source, r.relationship AS relationship, n.name AS destination, similarity, \
210 r.created_at AS rel_created_at, r.mentions AS rel_mentions, r.last_accessed_at AS rel_last_accessed \
211 } \
212 RETURN DISTINCT source, relationship, destination, similarity, rel_created_at, rel_mentions, rel_last_accessed \
213 ORDER BY similarity DESC \
214 LIMIT $limit";
215
216 let q = query(cypher)
217 .param("embedding", embedding.to_vec())
218 .param("threshold", threshold as f64)
219 .param("user_id", filter.user_id.as_deref().unwrap_or(""))
220 .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
221 .param("limit", limit as i64);
222
223 let mut result = self
224 .graph
225 .execute(q)
226 .await
227 .map_err(|e| Mem7Error::Graph(format!("embedding search error: {e}")))?;
228
229 let mut results = Vec::new();
230 while let Ok(Some(row)) = result.next().await {
231 let source: String = row.get("source").unwrap_or_default();
232 let relationship: String = row.get("relationship").unwrap_or_default();
233 let destination: String = row.get("destination").unwrap_or_default();
234 let similarity: f64 = row.get("similarity").unwrap_or_default();
235 let rel_created_at: Option<i64> = row.get("rel_created_at").ok();
236 let rel_mentions: Option<i64> = row.get("rel_mentions").ok();
237 let rel_last_accessed: Option<String> = row.get("rel_last_accessed").ok();
238
239 results.push(GraphSearchResult {
240 source,
241 relationship,
242 destination,
243 score: Some(similarity as f32),
244 created_at: rel_created_at.map(|ts| format!("{ts}")),
245 mentions: rel_mentions.map(|m| m as u32),
246 last_accessed_at: rel_last_accessed,
247 });
248 }
249
250 debug!(count = results.len(), "neo4j: embedding search results");
251 Ok(results)
252 }
253
254 async fn invalidate_relations(
255 &self,
256 triples: &[(String, String, String)],
257 filter: &MemoryFilter,
258 ) -> Result<()> {
259 for (src, rel, dst) in triples {
260 let q = query(
261 "MATCH (s:Entity {name: $src})-[r:RELATES {relationship: $rel}]->(d:Entity {name: $dst}) \
262 WHERE ($user_id = '' OR r.user_id = $user_id) \
263 SET r.valid = false",
264 )
265 .param("src", src.as_str())
266 .param("rel", rel.as_str())
267 .param("dst", dst.as_str())
268 .param("user_id", filter.user_id.as_deref().unwrap_or(""));
269
270 self.graph
271 .run(q)
272 .await
273 .map_err(|e| Mem7Error::Graph(format!("invalidate relation error: {e}")))?;
274 }
275
276 debug!(count = triples.len(), "neo4j: relations invalidated");
277 Ok(())
278 }
279
280 async fn rehearse_relations(
281 &self,
282 triples: &[(String, String, String)],
283 filter: &MemoryFilter,
284 now: &str,
285 ) -> Result<()> {
286 for (src, rel, dst) in triples {
287 let q = query(
288 "MATCH (s:Entity {name: $src})-[r:RELATES {relationship: $rel}]->(d:Entity {name: $dst}) \
289 WHERE r.valid = true AND ($user_id = '' OR r.user_id = $user_id) \
290 SET r.mentions = r.mentions + 1, r.last_accessed_at = $now",
291 )
292 .param("src", src.as_str())
293 .param("rel", rel.as_str())
294 .param("dst", dst.as_str())
295 .param("user_id", filter.user_id.as_deref().unwrap_or(""))
296 .param("now", now);
297
298 self.graph
299 .run(q)
300 .await
301 .map_err(|e| Mem7Error::Graph(format!("rehearse relation error: {e}")))?;
302 }
303
304 debug!(count = triples.len(), "neo4j: relations rehearsed");
305 Ok(())
306 }
307
308 async fn delete_all(&self, filter: &MemoryFilter) -> Result<()> {
309 if let Some(uid) = &filter.user_id {
310 let q = query(
311 "MATCH (s:Entity)-[r:RELATES]->(d:Entity) \
312 WHERE r.user_id = $user_id DELETE r",
313 )
314 .param("user_id", uid.as_str());
315
316 self.graph
317 .run(q)
318 .await
319 .map_err(|e| Mem7Error::Graph(format!("delete relations error: {e}")))?;
320
321 let q = query("MATCH (e:Entity) WHERE e.user_id = $user_id DELETE e")
322 .param("user_id", uid.as_str());
323
324 self.graph
325 .run(q)
326 .await
327 .map_err(|e| Mem7Error::Graph(format!("delete entities error: {e}")))?;
328 }
329
330 Ok(())
331 }
332
333 async fn reset(&self) -> Result<()> {
334 self.graph
335 .run(query("MATCH ()-[r:RELATES]->() DELETE r"))
336 .await
337 .map_err(|e| Mem7Error::Graph(format!("reset relations error: {e}")))?;
338
339 self.graph
340 .run(query("MATCH (e:Entity) DELETE e"))
341 .await
342 .map_err(|e| Mem7Error::Graph(format!("reset entities error: {e}")))?;
343
344 info!("neo4j: graph reset");
345 Ok(())
346 }
347}