1use std::sync::Arc;
2
3use async_trait::async_trait;
4use mem7_core::MemoryFilter;
5use mem7_error::{Mem7Error, Result};
6use neo4rs::{Graph, query};
7use tracing::{debug, info};
8
9use crate::GraphStore;
10use crate::types::{Entity, GraphSearchResult, Relation};
11
/// [`GraphStore`] implementation that keeps entities and relations in Neo4j.
///
/// Entities are stored as `Entity` nodes and relations as `RELATES` relationships;
/// all access goes through the `neo4rs` client held in `graph`.
pub struct Neo4jGraphStore {
    /// Connected `neo4rs` client, held behind an `Arc`.
    graph: Arc<Graph>,
}
16
17impl Neo4jGraphStore {
18 pub async fn new(
20 url: &str,
21 username: &str,
22 password: &str,
23 database: Option<&str>,
24 ) -> Result<Self> {
25 let mut config = neo4rs::ConfigBuilder::default()
26 .uri(url)
27 .user(username)
28 .password(password);
29
30 if let Some(db) = database {
31 config = config.db(db);
32 }
33
34 let graph = Graph::connect(
35 config
36 .build()
37 .map_err(|e| Mem7Error::Graph(format!("Neo4j config error: {e}")))?,
38 )
39 .await
40 .map_err(|e| Mem7Error::Graph(format!("Neo4j connection error: {e}")))?;
41
42 graph
43 .run(query(
44 "CREATE CONSTRAINT IF NOT EXISTS FOR (e:Entity) REQUIRE e.name IS UNIQUE",
45 ))
46 .await
47 .map_err(|e| Mem7Error::Graph(format!("constraint creation error: {e}")))?;
48
49 info!(url, "Neo4jGraphStore connected");
50 Ok(Self {
51 graph: Arc::new(graph),
52 })
53 }
54}
55
56#[async_trait]
57impl GraphStore for Neo4jGraphStore {
58 async fn add_entities(&self, entities: &[Entity], filter: &MemoryFilter) -> Result<()> {
59 for entity in entities {
60 let q = query(
61 "MERGE (e:Entity {name: $name}) \
62 ON CREATE SET e.entity_type = $entity_type, \
63 e.user_id = $user_id, \
64 e.agent_id = $agent_id, \
65 e.run_id = $run_id, \
66 e.created_at = timestamp(), \
67 e.mentions = 1 \
68 ON MATCH SET e.mentions = e.mentions + 1",
69 )
70 .param("name", entity.name.as_str())
71 .param("entity_type", entity.entity_type.as_str())
72 .param("user_id", filter.user_id.as_deref().unwrap_or(""))
73 .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
74 .param("run_id", filter.run_id.as_deref().unwrap_or(""));
75
76 self.graph
77 .run(q)
78 .await
79 .map_err(|e| Mem7Error::Graph(format!("add entity error: {e}")))?;
80
81 if let Some(emb) = &entity.embedding {
82 let emb_q = query(
83 "MATCH (e:Entity {name: $name}) \
84 CALL db.create.setNodeVectorProperty(e, 'embedding', $embedding)",
85 )
86 .param("name", entity.name.as_str())
87 .param("embedding", emb.clone());
88
89 self.graph
90 .run(emb_q)
91 .await
92 .map_err(|e| Mem7Error::Graph(format!("set embedding error: {e}")))?;
93 }
94 }
95
96 debug!(count = entities.len(), "neo4j: entities added");
97 Ok(())
98 }
99
100 async fn add_relations(
101 &self,
102 relations: &[Relation],
103 entities: &[Entity],
104 filter: &MemoryFilter,
105 ) -> Result<()> {
106 self.add_entities(entities, filter).await?;
107
108 for rel in relations {
109 let q = query(
110 "MATCH (s:Entity {name: $src}), (d:Entity {name: $dst}) \
111 MERGE (s)-[r:RELATES {relationship: $rel}]->(d) \
112 ON CREATE SET r.user_id = $user_id, \
113 r.agent_id = $agent_id, \
114 r.run_id = $run_id, \
115 r.created_at = timestamp(), \
116 r.mentions = 1, \
117 r.valid = true \
118 ON MATCH SET r.mentions = r.mentions + 1",
119 )
120 .param("src", rel.source.as_str())
121 .param("dst", rel.destination.as_str())
122 .param("rel", rel.relationship.as_str())
123 .param("user_id", filter.user_id.as_deref().unwrap_or(""))
124 .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
125 .param("run_id", filter.run_id.as_deref().unwrap_or(""));
126
127 self.graph
128 .run(q)
129 .await
130 .map_err(|e| Mem7Error::Graph(format!("add relation error: {e}")))?;
131 }
132
133 debug!(count = relations.len(), "neo4j: relations added");
134 Ok(())
135 }
136
137 async fn search(
138 &self,
139 query_str: &str,
140 filter: &MemoryFilter,
141 limit: usize,
142 ) -> Result<Vec<GraphSearchResult>> {
143 let cypher = "\
144 MATCH (s:Entity)-[r:RELATES]->(d:Entity) \
145 WHERE r.valid = true \
146 AND (toLower(s.name) CONTAINS toLower($query) \
147 OR toLower(d.name) CONTAINS toLower($query) \
148 OR toLower(r.relationship) CONTAINS toLower($query)) \
149 AND ($user_id = '' OR r.user_id = $user_id) \
150 AND ($agent_id = '' OR r.agent_id = $agent_id) \
151 AND ($run_id = '' OR r.run_id = $run_id) \
152 RETURN s.name AS source, r.relationship AS relationship, d.name AS destination \
153 LIMIT $limit";
154
155 let q = query(cypher)
156 .param("query", query_str)
157 .param("user_id", filter.user_id.as_deref().unwrap_or(""))
158 .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
159 .param("run_id", filter.run_id.as_deref().unwrap_or(""))
160 .param("limit", limit as i64);
161
162 let mut result = self
163 .graph
164 .execute(q)
165 .await
166 .map_err(|e| Mem7Error::Graph(format!("search error: {e}")))?;
167
168 let mut results = Vec::new();
169 while let Ok(Some(row)) = result.next().await {
170 let source: String = row.get("source").unwrap_or_default();
171 let relationship: String = row.get("relationship").unwrap_or_default();
172 let destination: String = row.get("destination").unwrap_or_default();
173
174 results.push(GraphSearchResult {
175 source,
176 relationship,
177 destination,
178 score: None,
179 created_at: None,
180 mentions: None,
181 last_accessed_at: None,
182 });
183 }
184
185 debug!(count = results.len(), "neo4j: search results");
186 Ok(results)
187 }
188
189 async fn search_by_embedding(
190 &self,
191 embedding: &[f32],
192 filter: &MemoryFilter,
193 threshold: f32,
194 limit: usize,
195 ) -> Result<Vec<GraphSearchResult>> {
196 let cypher = "\
197 MATCH (n:Entity) \
198 WHERE n.embedding IS NOT NULL \
199 AND ($user_id = '' OR n.user_id = $user_id) \
200 AND ($agent_id = '' OR n.agent_id = $agent_id) \
201 AND ($run_id = '' OR n.run_id = $run_id) \
202 WITH n, vector.similarity.cosine(n.embedding, $embedding) AS similarity \
203 WHERE similarity >= $threshold \
204 CALL { \
205 WITH n \
206 MATCH (n)-[r:RELATES]->(m:Entity) \
207 WHERE r.valid = true \
208 AND ($user_id = '' OR r.user_id = $user_id) \
209 AND ($agent_id = '' OR r.agent_id = $agent_id) \
210 AND ($run_id = '' OR r.run_id = $run_id) \
211 RETURN n.name AS source, r.relationship AS relationship, m.name AS destination, similarity, \
212 r.created_at AS rel_created_at, r.mentions AS rel_mentions, r.last_accessed_at AS rel_last_accessed \
213 UNION \
214 WITH n, similarity \
215 MATCH (n)<-[r:RELATES]-(m:Entity) \
216 WHERE r.valid = true \
217 AND ($user_id = '' OR r.user_id = $user_id) \
218 AND ($agent_id = '' OR r.agent_id = $agent_id) \
219 AND ($run_id = '' OR r.run_id = $run_id) \
220 RETURN m.name AS source, r.relationship AS relationship, n.name AS destination, similarity, \
221 r.created_at AS rel_created_at, r.mentions AS rel_mentions, r.last_accessed_at AS rel_last_accessed \
222 } \
223 RETURN DISTINCT source, relationship, destination, similarity, rel_created_at, rel_mentions, rel_last_accessed \
224 ORDER BY similarity DESC \
225 LIMIT $limit";
226
227 let q = query(cypher)
228 .param("embedding", embedding.to_vec())
229 .param("threshold", threshold as f64)
230 .param("user_id", filter.user_id.as_deref().unwrap_or(""))
231 .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
232 .param("run_id", filter.run_id.as_deref().unwrap_or(""))
233 .param("limit", limit as i64);
234
235 let mut result = self
236 .graph
237 .execute(q)
238 .await
239 .map_err(|e| Mem7Error::Graph(format!("embedding search error: {e}")))?;
240
241 let mut results = Vec::new();
242 while let Ok(Some(row)) = result.next().await {
243 let source: String = row.get("source").unwrap_or_default();
244 let relationship: String = row.get("relationship").unwrap_or_default();
245 let destination: String = row.get("destination").unwrap_or_default();
246 let similarity: f64 = row.get("similarity").unwrap_or_default();
247 let rel_created_at: Option<i64> = row.get("rel_created_at").ok();
248 let rel_mentions: Option<i64> = row.get("rel_mentions").ok();
249 let rel_last_accessed: Option<String> = row.get("rel_last_accessed").ok();
250
251 results.push(GraphSearchResult {
252 source,
253 relationship,
254 destination,
255 score: Some(similarity as f32),
256 created_at: rel_created_at.map(|ts| format!("{ts}")),
257 mentions: rel_mentions.map(|m| m as u32),
258 last_accessed_at: rel_last_accessed,
259 });
260 }
261
262 debug!(count = results.len(), "neo4j: embedding search results");
263 Ok(results)
264 }
265
266 async fn invalidate_relations(
267 &self,
268 triples: &[(String, String, String)],
269 filter: &MemoryFilter,
270 ) -> Result<()> {
271 for (src, rel, dst) in triples {
272 let q = query(
273 "MATCH (s:Entity {name: $src})-[r:RELATES {relationship: $rel}]->(d:Entity {name: $dst}) \
274 WHERE ($user_id = '' OR r.user_id = $user_id) \
275 AND ($agent_id = '' OR r.agent_id = $agent_id) \
276 AND ($run_id = '' OR r.run_id = $run_id) \
277 SET r.valid = false",
278 )
279 .param("src", src.as_str())
280 .param("rel", rel.as_str())
281 .param("dst", dst.as_str())
282 .param("user_id", filter.user_id.as_deref().unwrap_or(""))
283 .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
284 .param("run_id", filter.run_id.as_deref().unwrap_or(""));
285
286 self.graph
287 .run(q)
288 .await
289 .map_err(|e| Mem7Error::Graph(format!("invalidate relation error: {e}")))?;
290 }
291
292 debug!(count = triples.len(), "neo4j: relations invalidated");
293 Ok(())
294 }
295
296 async fn rehearse_relations(
297 &self,
298 triples: &[(String, String, String)],
299 filter: &MemoryFilter,
300 now: &str,
301 ) -> Result<()> {
302 for (src, rel, dst) in triples {
303 let q = query(
304 "MATCH (s:Entity {name: $src})-[r:RELATES {relationship: $rel}]->(d:Entity {name: $dst}) \
305 WHERE r.valid = true \
306 AND ($user_id = '' OR r.user_id = $user_id) \
307 AND ($agent_id = '' OR r.agent_id = $agent_id) \
308 AND ($run_id = '' OR r.run_id = $run_id) \
309 SET r.mentions = r.mentions + 1, r.last_accessed_at = $now",
310 )
311 .param("src", src.as_str())
312 .param("rel", rel.as_str())
313 .param("dst", dst.as_str())
314 .param("user_id", filter.user_id.as_deref().unwrap_or(""))
315 .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
316 .param("run_id", filter.run_id.as_deref().unwrap_or(""))
317 .param("now", now);
318
319 self.graph
320 .run(q)
321 .await
322 .map_err(|e| Mem7Error::Graph(format!("rehearse relation error: {e}")))?;
323 }
324
325 debug!(count = triples.len(), "neo4j: relations rehearsed");
326 Ok(())
327 }
328
329 async fn delete_all(&self, filter: &MemoryFilter) -> Result<()> {
330 let q = query(
331 "MATCH (s:Entity)-[r:RELATES]->(d:Entity) \
332 WHERE ($user_id = '' OR r.user_id = $user_id) \
333 AND ($agent_id = '' OR r.agent_id = $agent_id) \
334 AND ($run_id = '' OR r.run_id = $run_id) \
335 DELETE r",
336 )
337 .param("user_id", filter.user_id.as_deref().unwrap_or(""))
338 .param("agent_id", filter.agent_id.as_deref().unwrap_or(""))
339 .param("run_id", filter.run_id.as_deref().unwrap_or(""));
340
341 self.graph
342 .run(q)
343 .await
344 .map_err(|e| Mem7Error::Graph(format!("delete relations error: {e}")))?;
345
346 self.graph
347 .run(query("MATCH (e:Entity) WHERE NOT (e)--() DELETE e"))
348 .await
349 .map_err(|e| Mem7Error::Graph(format!("delete orphan entities error: {e}")))?;
350
351 Ok(())
352 }
353
354 async fn reset(&self) -> Result<()> {
355 self.graph
356 .run(query("MATCH ()-[r:RELATES]->() DELETE r"))
357 .await
358 .map_err(|e| Mem7Error::Graph(format!("reset relations error: {e}")))?;
359
360 self.graph
361 .run(query("MATCH (e:Entity) DELETE e"))
362 .await
363 .map_err(|e| Mem7Error::Graph(format!("reset entities error: {e}")))?;
364
365 info!("neo4j: graph reset");
366 Ok(())
367 }
368}