1use std::sync::RwLock;
2
3use async_trait::async_trait;
4use mem7_core::MemoryFilter;
5use mem7_error::Result;
6
7use crate::GraphStore;
8use crate::types::{Entity, GraphSearchResult, Relation};
9
/// An entity record held by [`FlatGraph`], tagged with the scope
/// (`user_id` / `agent_id` / `run_id`) it was written under.
#[derive(Debug, Clone)]
struct StoredEntity {
    /// Entity name; relations refer to entities by this name.
    name: String,
    /// Type label; overwritten on upsert when a different type is supplied.
    entity_type: String,
    /// Optional embedding compared against the query in `search_by_embedding`.
    embedding: Option<Vec<f32>>,
    // Copied from the incoming `Entity` on insert but not read anywhere in
    // this module yet, hence the `dead_code` allowance.
    #[allow(dead_code)]
    created_at: Option<String>,
    /// Number of times this entity has been added in its scope; starts at 1.
    mentions: u32,
    // Initialized from `created_at` on insert; not read by this module yet,
    // hence the `dead_code` allowance.
    #[allow(dead_code)]
    last_accessed_at: Option<String>,
    /// Scope identifiers copied from the `MemoryFilter` used at insert time.
    user_id: Option<String>,
    agent_id: Option<String>,
    run_id: Option<String>,
}
24
/// A directed `source -[relationship]-> destination` edge held by
/// [`FlatGraph`], tagged with the scope it was written under.
#[derive(Debug, Clone)]
struct StoredRelation {
    /// Name of the source entity.
    source: String,
    /// Relationship label, e.g. `works_at`.
    relationship: String,
    /// Name of the destination entity.
    destination: String,
    /// Creation timestamp copied from the incoming `Relation`.
    created_at: Option<String>,
    /// Incremented when the same triple is re-added or rehearsed; starts at 1.
    mentions: u32,
    /// Soft-delete flag: `invalidate_relations` clears it, and searches skip
    /// relations where it is `false`.
    valid: bool,
    /// Initialized from `created_at`; `rehearse_relations` overwrites it.
    last_accessed_at: Option<String>,
    /// Scope identifiers copied from the `MemoryFilter` used at insert time.
    user_id: Option<String>,
    agent_id: Option<String>,
    run_id: Option<String>,
}
38
/// In-memory [`GraphStore`] that keeps entities and relations in flat
/// vectors and answers queries with linear scans.
///
/// Every record carries the `user_id` / `agent_id` / `run_id` it was written
/// under; reads and writes are scoped by comparing those identifiers against
/// a [`MemoryFilter`].
pub struct FlatGraph {
    // Entities and relations are guarded by independent locks.
    entities: RwLock<Vec<StoredEntity>>,
    relations: RwLock<Vec<StoredRelation>>,
}
44
45impl FlatGraph {
46 pub fn new() -> Self {
47 Self {
48 entities: RwLock::new(Vec::new()),
49 relations: RwLock::new(Vec::new()),
50 }
51 }
52}
53
54impl Default for FlatGraph {
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
60fn matches_filter(
61 user_id: &Option<String>,
62 agent_id: &Option<String>,
63 run_id: &Option<String>,
64 filter: &MemoryFilter,
65) -> bool {
66 if let Some(uid) = &filter.user_id {
67 if user_id.as_deref() != Some(uid.as_str()) {
68 return false;
69 }
70 }
71 if let Some(aid) = &filter.agent_id {
72 if agent_id.as_deref() != Some(aid.as_str()) {
73 return false;
74 }
75 }
76 if let Some(rid) = &filter.run_id {
77 if run_id.as_deref() != Some(rid.as_str()) {
78 return false;
79 }
80 }
81 true
82}
83
84#[async_trait]
85impl GraphStore for FlatGraph {
86 async fn add_entities(&self, entities: &[Entity], filter: &MemoryFilter) -> Result<()> {
87 let mut store = self.entities.write().expect("entity lock poisoned");
88 for entity in entities {
89 if let Some(existing) = store.iter_mut().find(|e| {
90 e.name == entity.name && matches_filter(&e.user_id, &e.agent_id, &e.run_id, filter)
91 }) {
92 existing.mentions += 1;
93 if entity.embedding.is_some() {
94 existing.embedding.clone_from(&entity.embedding);
95 }
96 if entity.entity_type != existing.entity_type {
97 existing.entity_type.clone_from(&entity.entity_type);
98 }
99 } else {
100 store.push(StoredEntity {
101 name: entity.name.clone(),
102 entity_type: entity.entity_type.clone(),
103 embedding: entity.embedding.clone(),
104 created_at: entity.created_at.clone(),
105 mentions: 1,
106 last_accessed_at: entity.created_at.clone(),
107 user_id: filter.user_id.clone(),
108 agent_id: filter.agent_id.clone(),
109 run_id: filter.run_id.clone(),
110 });
111 }
112 }
113 Ok(())
114 }
115
116 async fn add_relations(
117 &self,
118 relations: &[Relation],
119 entities: &[Entity],
120 filter: &MemoryFilter,
121 ) -> Result<()> {
122 self.add_entities(entities, filter).await?;
123
124 let mut store = self.relations.write().expect("relation lock poisoned");
125 for r in relations {
126 if let Some(existing) = store.iter_mut().find(|e| {
127 e.source == r.source
128 && e.relationship == r.relationship
129 && e.destination == r.destination
130 && e.valid
131 && matches_filter(&e.user_id, &e.agent_id, &e.run_id, filter)
132 }) {
133 existing.mentions += 1;
134 } else {
135 store.push(StoredRelation {
136 source: r.source.clone(),
137 relationship: r.relationship.clone(),
138 destination: r.destination.clone(),
139 created_at: r.created_at.clone(),
140 mentions: 1,
141 valid: true,
142 last_accessed_at: r.created_at.clone(),
143 user_id: filter.user_id.clone(),
144 agent_id: filter.agent_id.clone(),
145 run_id: filter.run_id.clone(),
146 });
147 }
148 }
149 Ok(())
150 }
151
152 async fn search(
153 &self,
154 query: &str,
155 filter: &MemoryFilter,
156 limit: usize,
157 ) -> Result<Vec<GraphSearchResult>> {
158 let store = self.relations.read().expect("relation lock poisoned");
159 let query_lower = query.to_lowercase();
160
161 let results: Vec<GraphSearchResult> = store
162 .iter()
163 .filter(|r| {
164 r.valid
165 && matches_filter(&r.user_id, &r.agent_id, &r.run_id, filter)
166 && (r.source.to_lowercase().contains(&query_lower)
167 || r.destination.to_lowercase().contains(&query_lower)
168 || r.relationship.to_lowercase().contains(&query_lower))
169 })
170 .take(limit)
171 .map(|r| GraphSearchResult {
172 source: r.source.clone(),
173 relationship: r.relationship.clone(),
174 destination: r.destination.clone(),
175 score: None,
176 created_at: r.created_at.clone(),
177 mentions: Some(r.mentions),
178 last_accessed_at: r.last_accessed_at.clone(),
179 })
180 .collect();
181
182 Ok(results)
183 }
184
185 async fn search_by_embedding(
186 &self,
187 embedding: &[f32],
188 filter: &MemoryFilter,
189 threshold: f32,
190 limit: usize,
191 ) -> Result<Vec<GraphSearchResult>> {
192 let entities = self.entities.read().expect("entity lock poisoned");
193
194 let matched_names: Vec<(&str, f32)> = entities
196 .iter()
197 .filter(|e| matches_filter(&e.user_id, &e.agent_id, &e.run_id, filter))
198 .filter_map(|e| {
199 e.embedding.as_ref().map(|emb| {
200 let sim = mem7_vector::cosine_similarity(emb, embedding);
201 (e.name.as_str(), sim)
202 })
203 })
204 .filter(|(_, sim)| *sim >= threshold)
205 .collect();
206
207 if matched_names.is_empty() {
208 return Ok(Vec::new());
209 }
210
211 let relations = self.relations.read().expect("relation lock poisoned");
213 let mut results: Vec<GraphSearchResult> = Vec::new();
214 let mut seen = std::collections::HashSet::new();
215
216 for (name, sim) in &matched_names {
217 for r in relations.iter() {
218 if !r.valid || !matches_filter(&r.user_id, &r.agent_id, &r.run_id, filter) {
219 continue;
220 }
221 if r.source.as_str() == *name || r.destination.as_str() == *name {
222 let key = (
223 r.source.clone(),
224 r.relationship.clone(),
225 r.destination.clone(),
226 );
227 if seen.insert(key) {
228 results.push(GraphSearchResult {
229 source: r.source.clone(),
230 relationship: r.relationship.clone(),
231 destination: r.destination.clone(),
232 score: Some(*sim),
233 created_at: r.created_at.clone(),
234 mentions: Some(r.mentions),
235 last_accessed_at: r.last_accessed_at.clone(),
236 });
237 }
238 }
239 }
240 }
241
242 results.sort_by(|a, b| {
243 b.score
244 .unwrap_or(0.0)
245 .partial_cmp(&a.score.unwrap_or(0.0))
246 .unwrap_or(std::cmp::Ordering::Equal)
247 });
248 results.truncate(limit);
249
250 Ok(results)
251 }
252
253 async fn invalidate_relations(
254 &self,
255 triples: &[(String, String, String)],
256 filter: &MemoryFilter,
257 ) -> Result<()> {
258 let mut store = self.relations.write().expect("relation lock poisoned");
259 for r in store.iter_mut() {
260 if !matches_filter(&r.user_id, &r.agent_id, &r.run_id, filter) {
261 continue;
262 }
263 for (src, rel, dst) in triples {
264 if r.source == *src && r.relationship == *rel && r.destination == *dst && r.valid {
265 r.valid = false;
266 }
267 }
268 }
269 Ok(())
270 }
271
272 async fn rehearse_relations(
273 &self,
274 triples: &[(String, String, String)],
275 filter: &MemoryFilter,
276 now: &str,
277 ) -> Result<()> {
278 let mut store = self.relations.write().expect("relation lock poisoned");
279 for r in store.iter_mut() {
280 if !r.valid || !matches_filter(&r.user_id, &r.agent_id, &r.run_id, filter) {
281 continue;
282 }
283 for (src, rel, dst) in triples {
284 if r.source == *src && r.relationship == *rel && r.destination == *dst {
285 r.mentions += 1;
286 r.last_accessed_at = Some(now.to_string());
287 }
288 }
289 }
290 Ok(())
291 }
292
293 async fn delete_all(&self, filter: &MemoryFilter) -> Result<()> {
294 let mut rel_store = self.relations.write().expect("relation lock poisoned");
295 rel_store.retain(|r| !matches_filter(&r.user_id, &r.agent_id, &r.run_id, filter));
296
297 let referenced_entities: std::collections::HashSet<String> = rel_store
298 .iter()
299 .flat_map(|r| [r.source.clone(), r.destination.clone()])
300 .collect();
301
302 let mut ent_store = self.entities.write().expect("entity lock poisoned");
303 ent_store.retain(|e| referenced_entities.contains(&e.name));
304
305 Ok(())
306 }
307
308 async fn reset(&self) -> Result<()> {
309 self.relations
310 .write()
311 .expect("relation lock poisoned")
312 .clear();
313 self.entities.write().expect("entity lock poisoned").clear();
314 Ok(())
315 }
316}
317
// Behavioral tests for `FlatGraph`, driven through the `GraphStore` trait.
#[cfg(test)]
mod tests {
    use super::*;

    // Filter scoped to a single user, with no agent/run constraint.
    fn test_filter(user_id: &str) -> MemoryFilter {
        MemoryFilter {
            user_id: Some(user_id.to_string()),
            agent_id: None,
            run_id: None,
            metadata: None,
        }
    }

    // Filter pinned to a specific user, agent, and run.
    fn scoped_filter(user_id: &str, agent_id: &str, run_id: &str) -> MemoryFilter {
        MemoryFilter {
            user_id: Some(user_id.to_string()),
            agent_id: Some(agent_id.to_string()),
            run_id: Some(run_id.to_string()),
            metadata: None,
        }
    }

    // Builds an `Entity` with no timestamp and zero mentions.
    fn make_entity(name: &str, etype: &str, embedding: Option<Vec<f32>>) -> Entity {
        Entity {
            name: name.into(),
            entity_type: etype.into(),
            embedding,
            created_at: None,
            mentions: 0,
        }
    }

    // Builds a valid `Relation` with no timestamp and zero mentions.
    fn make_relation(src: &str, rel: &str, dst: &str) -> Relation {
        Relation {
            source: src.into(),
            relationship: rel.into(),
            destination: dst.into(),
            created_at: None,
            mentions: 0,
            valid: true,
        }
    }

    // A stored relation is found by a substring of its source name.
    #[tokio::test]
    async fn add_and_search_relations() {
        let graph = FlatGraph::new();
        let filter = test_filter("user1");

        let entities = vec![
            make_entity("Alice", "Person", None),
            make_entity("tennis", "Activity", None),
        ];
        let relations = vec![make_relation("Alice", "loves_playing", "tennis")];

        graph
            .add_relations(&relations, &entities, &filter)
            .await
            .unwrap();

        let results = graph.search("Alice", &filter, 10).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].source, "Alice");
        assert_eq!(results[0].relationship, "loves_playing");
        assert_eq!(results[0].destination, "tennis");
    }

    // Re-adding the same entity in the same scope upserts instead of duplicating.
    #[tokio::test]
    async fn add_entities_stores_and_upserts() {
        let graph = FlatGraph::new();
        let filter = test_filter("user1");

        let entities = vec![make_entity("Alice", "Person", Some(vec![1.0, 0.0]))];
        graph.add_entities(&entities, &filter).await.unwrap();

        graph.add_entities(&entities, &filter).await.unwrap();

        let store = graph.entities.read().unwrap();
        assert_eq!(store.len(), 1);
        assert_eq!(store[0].mentions, 2);
        assert!(store[0].embedding.is_some());
    }

    // Both relations touching the similar entity ("Alice") are returned.
    #[tokio::test]
    async fn search_by_embedding_finds_related() {
        let graph = FlatGraph::new();
        let filter = test_filter("user1");

        let entities = vec![
            make_entity("Alice", "Person", Some(vec![1.0, 0.0, 0.0])),
            make_entity("Bob", "Person", Some(vec![0.0, 1.0, 0.0])),
        ];
        let relations = vec![
            make_relation("Alice", "friend_of", "Bob"),
            make_relation("Alice", "likes", "tennis"),
        ];

        graph
            .add_relations(&relations, &entities, &filter)
            .await
            .unwrap();

        let query_emb = vec![0.99, 0.01, 0.0];
        let results = graph
            .search_by_embedding(&query_emb, &filter, 0.7, 10)
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
    }

    // An orthogonal query vector falls below the threshold and yields nothing.
    #[tokio::test]
    async fn search_by_embedding_respects_threshold() {
        let graph = FlatGraph::new();
        let filter = test_filter("user1");

        let entities = vec![make_entity("Alice", "Person", Some(vec![1.0, 0.0]))];
        let relations = vec![make_relation("Alice", "likes", "coffee")];

        graph
            .add_relations(&relations, &entities, &filter)
            .await
            .unwrap();

        let query_emb = vec![0.0, 1.0];
        let results = graph
            .search_by_embedding(&query_emb, &filter, 0.7, 10)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    // Invalidated relations are hidden from search; the others remain.
    #[tokio::test]
    async fn invalidate_relations_soft_deletes() {
        let graph = FlatGraph::new();
        let filter = test_filter("user1");

        let entities = vec![make_entity("USER", "Person", None)];
        let relations = vec![
            make_relation("USER", "works_at", "Google"),
            make_relation("USER", "lives_in", "NYC"),
        ];

        graph
            .add_relations(&relations, &entities, &filter)
            .await
            .unwrap();

        graph
            .invalidate_relations(
                &[("USER".into(), "works_at".into(), "Google".into())],
                &filter,
            )
            .await
            .unwrap();

        let results = graph.search("USER", &filter, 10).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].relationship, "lives_in");
    }

    // Re-adding an identical triple bumps `mentions` instead of duplicating it.
    #[tokio::test]
    async fn relation_dedup_increments_mentions() {
        let graph = FlatGraph::new();
        let filter = test_filter("user1");

        let entities = vec![make_entity("Alice", "Person", None)];
        let relations = vec![make_relation("Alice", "likes", "coffee")];

        graph
            .add_relations(&relations, &entities, &filter)
            .await
            .unwrap();
        graph
            .add_relations(&relations, &entities, &filter)
            .await
            .unwrap();

        let store = graph.relations.read().unwrap();
        assert_eq!(store.len(), 1);
        assert_eq!(store[0].mentions, 2);
    }

    // Text search also matches on the relationship label.
    #[tokio::test]
    async fn search_by_relationship() {
        let graph = FlatGraph::new();
        let filter = test_filter("user1");

        let entities = vec![
            make_entity("Bob", "Person", None),
            make_entity("Google", "Organization", None),
        ];
        let relations = vec![make_relation("Bob", "works_at", "Google")];

        graph
            .add_relations(&relations, &entities, &filter)
            .await
            .unwrap();

        let results = graph.search("works", &filter, 10).await.unwrap();
        assert_eq!(results.len(), 1);
    }

    // Data written for one user is invisible to another.
    #[tokio::test]
    async fn search_respects_user_scope() {
        let graph = FlatGraph::new();
        let filter1 = test_filter("user1");
        let filter2 = test_filter("user2");

        let entities = vec![make_entity("X", "Other", None)];
        let rels = vec![make_relation("X", "rel", "Y")];

        graph
            .add_relations(&rels, &entities, &filter1)
            .await
            .unwrap();

        let r1 = graph.search("X", &filter1, 10).await.unwrap();
        assert_eq!(r1.len(), 1);

        let r2 = graph.search("X", &filter2, 10).await.unwrap();
        assert_eq!(r2.len(), 0);
    }

    // Query case does not affect matching.
    #[tokio::test]
    async fn search_case_insensitive() {
        let graph = FlatGraph::new();
        let filter = test_filter("u");

        let entities = vec![make_entity("Alice", "Person", None)];
        let rels = vec![make_relation("Alice", "likes", "Coffee")];

        graph
            .add_relations(&rels, &entities, &filter)
            .await
            .unwrap();

        assert_eq!(graph.search("alice", &filter, 10).await.unwrap().len(), 1);
        assert_eq!(graph.search("COFFEE", &filter, 10).await.unwrap().len(), 1);
    }

    // `limit` caps the number of text-search results.
    #[tokio::test]
    async fn search_limit() {
        let graph = FlatGraph::new();
        let filter = test_filter("u");

        let entities = vec![make_entity("A", "Other", None)];

        for i in 0..10 {
            let rels = vec![make_relation("A", &format!("rel_{i}"), &format!("B{i}"))];
            graph
                .add_relations(&rels, &entities, &filter)
                .await
                .unwrap();
        }

        let r = graph.search("A", &filter, 3).await.unwrap();
        assert_eq!(r.len(), 3);
    }

    // Deleting one user's data leaves another user's identical relation intact.
    #[tokio::test]
    async fn delete_all_by_user() {
        let graph = FlatGraph::new();
        let filter1 = test_filter("user1");
        let filter2 = test_filter("user2");

        let entities = vec![make_entity("X", "Other", None)];
        let rels = vec![make_relation("X", "r", "Y")];

        graph
            .add_relations(&rels, &entities, &filter1)
            .await
            .unwrap();
        graph
            .add_relations(&rels, &entities, &filter2)
            .await
            .unwrap();

        graph.delete_all(&filter1).await.unwrap();

        // An empty filter matches every scope, so this sees what survived.
        let empty_filter = MemoryFilter::default();
        let r = graph.search("X", &empty_filter, 10).await.unwrap();
        assert_eq!(r.len(), 1);
    }

    // A fully pinned filter deletes only its own agent/run data.
    #[tokio::test]
    async fn delete_all_respects_agent_and_run_scope() {
        let graph = FlatGraph::new();
        let scoped_a = scoped_filter("user1", "agent-a", "run-a");
        let scoped_b = scoped_filter("user1", "agent-b", "run-b");

        let entities = vec![make_entity("Shared", "Other", None)];
        let rels_a = vec![make_relation("Shared", "likes", "Rust")];
        let rels_b = vec![make_relation("Shared", "likes", "Python")];

        graph
            .add_relations(&rels_a, &entities, &scoped_a)
            .await
            .unwrap();
        graph
            .add_relations(&rels_b, &entities, &scoped_b)
            .await
            .unwrap();

        graph.delete_all(&scoped_a).await.unwrap();

        let remaining_a = graph.search("Shared", &scoped_a, 10).await.unwrap();
        let remaining_b = graph.search("Shared", &scoped_b, 10).await.unwrap();
        assert!(remaining_a.is_empty());
        assert_eq!(remaining_b.len(), 1);
        assert_eq!(remaining_b[0].destination, "Python");
    }

    // `reset` empties both the relation and the entity stores.
    #[tokio::test]
    async fn reset_clears_all() {
        let graph = FlatGraph::new();
        let filter = test_filter("u");

        let entities = vec![make_entity("X", "Other", None)];
        let rels = vec![make_relation("X", "r", "Y")];

        graph
            .add_relations(&rels, &entities, &filter)
            .await
            .unwrap();

        graph.reset().await.unwrap();

        let empty_filter = MemoryFilter::default();
        assert!(
            graph
                .search("X", &empty_filter, 10)
                .await
                .unwrap()
                .is_empty()
        );
        assert!(graph.entities.read().unwrap().is_empty());
    }
}
661}