Skip to main content

mem7_graph/
flat.rs

1use std::sync::RwLock;
2
3use async_trait::async_trait;
4use mem7_core::MemoryFilter;
5use mem7_error::Result;
6
7use crate::GraphStore;
8use crate::types::{Entity, GraphSearchResult, Relation};
9
/// An entity as held by [`FlatGraph`], tagged with the scope it was stored under.
#[derive(Clone, Debug)]
struct StoredEntity {
    /// Entity name; the key used for upserts and for matching relation endpoints.
    name: String,
    /// Free-form type label (e.g. "Person"); replaced when a re-add supplies a different one.
    entity_type: String,
    /// Optional embedding consulted by `search_by_embedding`.
    embedding: Option<Vec<f32>>,
    /// Creation timestamp copied from the incoming entity.
    // Recorded but never read anywhere in this file, hence the allow.
    #[allow(dead_code)]
    created_at: Option<String>,
    /// Starts at 1 and is incremented each time an upsert matches this entity.
    mentions: u32,
    /// Last-access timestamp, initialised from `created_at`.
    // Recorded but never read anywhere in this file, hence the allow.
    #[allow(dead_code)]
    last_accessed_at: Option<String>,
    /// Scope: owning user, if any.
    user_id: Option<String>,
    /// Scope: owning agent, if any.
    agent_id: Option<String>,
    /// Scope: owning run, if any.
    run_id: Option<String>,
}
24
/// A `(source, relationship, destination)` triple as held by [`FlatGraph`], plus
/// bookkeeping and the scope it was stored under.
#[derive(Clone, Debug)]
struct StoredRelation {
    /// Name of the entity the edge starts from.
    source: String,
    /// Edge label, e.g. "works_at".
    relationship: String,
    /// Name of the entity the edge points to.
    destination: String,
    /// Creation timestamp copied from the incoming relation.
    created_at: Option<String>,
    /// Starts at 1; bumped on duplicate adds and on rehearsal.
    mentions: u32,
    /// Becomes `false` once the relation is soft-deleted by `invalidate_relations`.
    valid: bool,
    /// Timestamp of the latest rehearsal; initialised from `created_at`.
    last_accessed_at: Option<String>,
    /// Scope: owning user, if any.
    user_id: Option<String>,
    /// Scope: owning agent, if any.
    agent_id: Option<String>,
    /// Scope: owning run, if any.
    run_id: Option<String>,
}
38
39/// In-memory graph store for development and testing.
40pub struct FlatGraph {
41    entities: RwLock<Vec<StoredEntity>>,
42    relations: RwLock<Vec<StoredRelation>>,
43}
44
45impl FlatGraph {
46    pub fn new() -> Self {
47        Self {
48            entities: RwLock::new(Vec::new()),
49            relations: RwLock::new(Vec::new()),
50        }
51    }
52}
53
54impl Default for FlatGraph {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60fn matches_filter(
61    user_id: &Option<String>,
62    agent_id: &Option<String>,
63    run_id: &Option<String>,
64    filter: &MemoryFilter,
65) -> bool {
66    if let Some(uid) = &filter.user_id {
67        if user_id.as_deref() != Some(uid.as_str()) {
68            return false;
69        }
70    }
71    if let Some(aid) = &filter.agent_id {
72        if agent_id.as_deref() != Some(aid.as_str()) {
73            return false;
74        }
75    }
76    if let Some(rid) = &filter.run_id {
77        if run_id.as_deref() != Some(rid.as_str()) {
78            return false;
79        }
80    }
81    true
82}
83
84#[async_trait]
85impl GraphStore for FlatGraph {
86    async fn add_entities(&self, entities: &[Entity], filter: &MemoryFilter) -> Result<()> {
87        let mut store = self.entities.write().expect("entity lock poisoned");
88        for entity in entities {
89            if let Some(existing) = store.iter_mut().find(|e| {
90                e.name == entity.name && matches_filter(&e.user_id, &e.agent_id, &e.run_id, filter)
91            }) {
92                existing.mentions += 1;
93                if entity.embedding.is_some() {
94                    existing.embedding.clone_from(&entity.embedding);
95                }
96                if entity.entity_type != existing.entity_type {
97                    existing.entity_type.clone_from(&entity.entity_type);
98                }
99            } else {
100                store.push(StoredEntity {
101                    name: entity.name.clone(),
102                    entity_type: entity.entity_type.clone(),
103                    embedding: entity.embedding.clone(),
104                    created_at: entity.created_at.clone(),
105                    mentions: 1,
106                    last_accessed_at: entity.created_at.clone(),
107                    user_id: filter.user_id.clone(),
108                    agent_id: filter.agent_id.clone(),
109                    run_id: filter.run_id.clone(),
110                });
111            }
112        }
113        Ok(())
114    }
115
116    async fn add_relations(
117        &self,
118        relations: &[Relation],
119        entities: &[Entity],
120        filter: &MemoryFilter,
121    ) -> Result<()> {
122        self.add_entities(entities, filter).await?;
123
124        let mut store = self.relations.write().expect("relation lock poisoned");
125        for r in relations {
126            if let Some(existing) = store.iter_mut().find(|e| {
127                e.source == r.source
128                    && e.relationship == r.relationship
129                    && e.destination == r.destination
130                    && e.valid
131                    && matches_filter(&e.user_id, &e.agent_id, &e.run_id, filter)
132            }) {
133                existing.mentions += 1;
134            } else {
135                store.push(StoredRelation {
136                    source: r.source.clone(),
137                    relationship: r.relationship.clone(),
138                    destination: r.destination.clone(),
139                    created_at: r.created_at.clone(),
140                    mentions: 1,
141                    valid: true,
142                    last_accessed_at: r.created_at.clone(),
143                    user_id: filter.user_id.clone(),
144                    agent_id: filter.agent_id.clone(),
145                    run_id: filter.run_id.clone(),
146                });
147            }
148        }
149        Ok(())
150    }
151
152    async fn search(
153        &self,
154        query: &str,
155        filter: &MemoryFilter,
156        limit: usize,
157    ) -> Result<Vec<GraphSearchResult>> {
158        let store = self.relations.read().expect("relation lock poisoned");
159        let query_lower = query.to_lowercase();
160
161        let results: Vec<GraphSearchResult> = store
162            .iter()
163            .filter(|r| {
164                r.valid
165                    && matches_filter(&r.user_id, &r.agent_id, &r.run_id, filter)
166                    && (r.source.to_lowercase().contains(&query_lower)
167                        || r.destination.to_lowercase().contains(&query_lower)
168                        || r.relationship.to_lowercase().contains(&query_lower))
169            })
170            .take(limit)
171            .map(|r| GraphSearchResult {
172                source: r.source.clone(),
173                relationship: r.relationship.clone(),
174                destination: r.destination.clone(),
175                score: None,
176                created_at: r.created_at.clone(),
177                mentions: Some(r.mentions),
178                last_accessed_at: r.last_accessed_at.clone(),
179            })
180            .collect();
181
182        Ok(results)
183    }
184
185    async fn search_by_embedding(
186        &self,
187        embedding: &[f32],
188        filter: &MemoryFilter,
189        threshold: f32,
190        limit: usize,
191    ) -> Result<Vec<GraphSearchResult>> {
192        let entities = self.entities.read().expect("entity lock poisoned");
193
194        // Find entities whose embedding is above the similarity threshold
195        let matched_names: Vec<(&str, f32)> = entities
196            .iter()
197            .filter(|e| matches_filter(&e.user_id, &e.agent_id, &e.run_id, filter))
198            .filter_map(|e| {
199                e.embedding.as_ref().map(|emb| {
200                    let sim = mem7_vector::cosine_similarity(emb, embedding);
201                    (e.name.as_str(), sim)
202                })
203            })
204            .filter(|(_, sim)| *sim >= threshold)
205            .collect();
206
207        if matched_names.is_empty() {
208            return Ok(Vec::new());
209        }
210
211        // 1-hop: collect all valid relations touching matched entities
212        let relations = self.relations.read().expect("relation lock poisoned");
213        let mut results: Vec<GraphSearchResult> = Vec::new();
214        let mut seen = std::collections::HashSet::new();
215
216        for (name, sim) in &matched_names {
217            for r in relations.iter() {
218                if !r.valid || !matches_filter(&r.user_id, &r.agent_id, &r.run_id, filter) {
219                    continue;
220                }
221                if r.source.as_str() == *name || r.destination.as_str() == *name {
222                    let key = (
223                        r.source.clone(),
224                        r.relationship.clone(),
225                        r.destination.clone(),
226                    );
227                    if seen.insert(key) {
228                        results.push(GraphSearchResult {
229                            source: r.source.clone(),
230                            relationship: r.relationship.clone(),
231                            destination: r.destination.clone(),
232                            score: Some(*sim),
233                            created_at: r.created_at.clone(),
234                            mentions: Some(r.mentions),
235                            last_accessed_at: r.last_accessed_at.clone(),
236                        });
237                    }
238                }
239            }
240        }
241
242        results.sort_by(|a, b| {
243            b.score
244                .unwrap_or(0.0)
245                .partial_cmp(&a.score.unwrap_or(0.0))
246                .unwrap_or(std::cmp::Ordering::Equal)
247        });
248        results.truncate(limit);
249
250        Ok(results)
251    }
252
253    async fn invalidate_relations(
254        &self,
255        triples: &[(String, String, String)],
256        filter: &MemoryFilter,
257    ) -> Result<()> {
258        let mut store = self.relations.write().expect("relation lock poisoned");
259        for r in store.iter_mut() {
260            if !matches_filter(&r.user_id, &r.agent_id, &r.run_id, filter) {
261                continue;
262            }
263            for (src, rel, dst) in triples {
264                if r.source == *src && r.relationship == *rel && r.destination == *dst && r.valid {
265                    r.valid = false;
266                }
267            }
268        }
269        Ok(())
270    }
271
272    async fn rehearse_relations(
273        &self,
274        triples: &[(String, String, String)],
275        filter: &MemoryFilter,
276        now: &str,
277    ) -> Result<()> {
278        let mut store = self.relations.write().expect("relation lock poisoned");
279        for r in store.iter_mut() {
280            if !r.valid || !matches_filter(&r.user_id, &r.agent_id, &r.run_id, filter) {
281                continue;
282            }
283            for (src, rel, dst) in triples {
284                if r.source == *src && r.relationship == *rel && r.destination == *dst {
285                    r.mentions += 1;
286                    r.last_accessed_at = Some(now.to_string());
287                }
288            }
289        }
290        Ok(())
291    }
292
293    async fn delete_all(&self, filter: &MemoryFilter) -> Result<()> {
294        let mut rel_store = self.relations.write().expect("relation lock poisoned");
295        rel_store.retain(|r| !matches_filter(&r.user_id, &r.agent_id, &r.run_id, filter));
296
297        let referenced_entities: std::collections::HashSet<String> = rel_store
298            .iter()
299            .flat_map(|r| [r.source.clone(), r.destination.clone()])
300            .collect();
301
302        let mut ent_store = self.entities.write().expect("entity lock poisoned");
303        ent_store.retain(|e| referenced_entities.contains(&e.name));
304
305        Ok(())
306    }
307
308    async fn reset(&self) -> Result<()> {
309        self.relations
310            .write()
311            .expect("relation lock poisoned")
312            .clear();
313        self.entities.write().expect("entity lock poisoned").clear();
314        Ok(())
315    }
316}
317
318#[cfg(test)]
319mod tests {
320    use super::*;
321
322    fn test_filter(user_id: &str) -> MemoryFilter {
323        MemoryFilter {
324            user_id: Some(user_id.to_string()),
325            agent_id: None,
326            run_id: None,
327            metadata: None,
328        }
329    }
330
331    fn scoped_filter(user_id: &str, agent_id: &str, run_id: &str) -> MemoryFilter {
332        MemoryFilter {
333            user_id: Some(user_id.to_string()),
334            agent_id: Some(agent_id.to_string()),
335            run_id: Some(run_id.to_string()),
336            metadata: None,
337        }
338    }
339
340    fn make_entity(name: &str, etype: &str, embedding: Option<Vec<f32>>) -> Entity {
341        Entity {
342            name: name.into(),
343            entity_type: etype.into(),
344            embedding,
345            created_at: None,
346            mentions: 0,
347        }
348    }
349
350    fn make_relation(src: &str, rel: &str, dst: &str) -> Relation {
351        Relation {
352            source: src.into(),
353            relationship: rel.into(),
354            destination: dst.into(),
355            created_at: None,
356            mentions: 0,
357            valid: true,
358        }
359    }
360
361    #[tokio::test]
362    async fn add_and_search_relations() {
363        let graph = FlatGraph::new();
364        let filter = test_filter("user1");
365
366        let entities = vec![
367            make_entity("Alice", "Person", None),
368            make_entity("tennis", "Activity", None),
369        ];
370        let relations = vec![make_relation("Alice", "loves_playing", "tennis")];
371
372        graph
373            .add_relations(&relations, &entities, &filter)
374            .await
375            .unwrap();
376
377        let results = graph.search("Alice", &filter, 10).await.unwrap();
378        assert_eq!(results.len(), 1);
379        assert_eq!(results[0].source, "Alice");
380        assert_eq!(results[0].relationship, "loves_playing");
381        assert_eq!(results[0].destination, "tennis");
382    }
383
384    #[tokio::test]
385    async fn add_entities_stores_and_upserts() {
386        let graph = FlatGraph::new();
387        let filter = test_filter("user1");
388
389        let entities = vec![make_entity("Alice", "Person", Some(vec![1.0, 0.0]))];
390        graph.add_entities(&entities, &filter).await.unwrap();
391
392        // Adding again should increment mentions
393        graph.add_entities(&entities, &filter).await.unwrap();
394
395        let store = graph.entities.read().unwrap();
396        assert_eq!(store.len(), 1);
397        assert_eq!(store[0].mentions, 2);
398        assert!(store[0].embedding.is_some());
399    }
400
401    #[tokio::test]
402    async fn search_by_embedding_finds_related() {
403        let graph = FlatGraph::new();
404        let filter = test_filter("user1");
405
406        let entities = vec![
407            make_entity("Alice", "Person", Some(vec![1.0, 0.0, 0.0])),
408            make_entity("Bob", "Person", Some(vec![0.0, 1.0, 0.0])),
409        ];
410        let relations = vec![
411            make_relation("Alice", "friend_of", "Bob"),
412            make_relation("Alice", "likes", "tennis"),
413        ];
414
415        graph
416            .add_relations(&relations, &entities, &filter)
417            .await
418            .unwrap();
419
420        // Query embedding very similar to Alice's
421        let query_emb = vec![0.99, 0.01, 0.0];
422        let results = graph
423            .search_by_embedding(&query_emb, &filter, 0.7, 10)
424            .await
425            .unwrap();
426
427        // Should find relations touching Alice (2 relations)
428        assert_eq!(results.len(), 2);
429    }
430
431    #[tokio::test]
432    async fn search_by_embedding_respects_threshold() {
433        let graph = FlatGraph::new();
434        let filter = test_filter("user1");
435
436        let entities = vec![make_entity("Alice", "Person", Some(vec![1.0, 0.0]))];
437        let relations = vec![make_relation("Alice", "likes", "coffee")];
438
439        graph
440            .add_relations(&relations, &entities, &filter)
441            .await
442            .unwrap();
443
444        // Orthogonal embedding — should not match
445        let query_emb = vec![0.0, 1.0];
446        let results = graph
447            .search_by_embedding(&query_emb, &filter, 0.7, 10)
448            .await
449            .unwrap();
450        assert!(results.is_empty());
451    }
452
453    #[tokio::test]
454    async fn invalidate_relations_soft_deletes() {
455        let graph = FlatGraph::new();
456        let filter = test_filter("user1");
457
458        let entities = vec![make_entity("USER", "Person", None)];
459        let relations = vec![
460            make_relation("USER", "works_at", "Google"),
461            make_relation("USER", "lives_in", "NYC"),
462        ];
463
464        graph
465            .add_relations(&relations, &entities, &filter)
466            .await
467            .unwrap();
468
469        // Invalidate only works_at
470        graph
471            .invalidate_relations(
472                &[("USER".into(), "works_at".into(), "Google".into())],
473                &filter,
474            )
475            .await
476            .unwrap();
477
478        // Text search should only find lives_in (valid=true)
479        let results = graph.search("USER", &filter, 10).await.unwrap();
480        assert_eq!(results.len(), 1);
481        assert_eq!(results[0].relationship, "lives_in");
482    }
483
484    #[tokio::test]
485    async fn relation_dedup_increments_mentions() {
486        let graph = FlatGraph::new();
487        let filter = test_filter("user1");
488
489        let entities = vec![make_entity("Alice", "Person", None)];
490        let relations = vec![make_relation("Alice", "likes", "coffee")];
491
492        graph
493            .add_relations(&relations, &entities, &filter)
494            .await
495            .unwrap();
496        graph
497            .add_relations(&relations, &entities, &filter)
498            .await
499            .unwrap();
500
501        let store = graph.relations.read().unwrap();
502        assert_eq!(store.len(), 1);
503        assert_eq!(store[0].mentions, 2);
504    }
505
506    #[tokio::test]
507    async fn search_by_relationship() {
508        let graph = FlatGraph::new();
509        let filter = test_filter("user1");
510
511        let entities = vec![
512            make_entity("Bob", "Person", None),
513            make_entity("Google", "Organization", None),
514        ];
515        let relations = vec![make_relation("Bob", "works_at", "Google")];
516
517        graph
518            .add_relations(&relations, &entities, &filter)
519            .await
520            .unwrap();
521
522        let results = graph.search("works", &filter, 10).await.unwrap();
523        assert_eq!(results.len(), 1);
524    }
525
526    #[tokio::test]
527    async fn search_respects_user_scope() {
528        let graph = FlatGraph::new();
529        let filter1 = test_filter("user1");
530        let filter2 = test_filter("user2");
531
532        let entities = vec![make_entity("X", "Other", None)];
533        let rels = vec![make_relation("X", "rel", "Y")];
534
535        graph
536            .add_relations(&rels, &entities, &filter1)
537            .await
538            .unwrap();
539
540        let r1 = graph.search("X", &filter1, 10).await.unwrap();
541        assert_eq!(r1.len(), 1);
542
543        let r2 = graph.search("X", &filter2, 10).await.unwrap();
544        assert_eq!(r2.len(), 0);
545    }
546
547    #[tokio::test]
548    async fn search_case_insensitive() {
549        let graph = FlatGraph::new();
550        let filter = test_filter("u");
551
552        let entities = vec![make_entity("Alice", "Person", None)];
553        let rels = vec![make_relation("Alice", "likes", "Coffee")];
554
555        graph
556            .add_relations(&rels, &entities, &filter)
557            .await
558            .unwrap();
559
560        assert_eq!(graph.search("alice", &filter, 10).await.unwrap().len(), 1);
561        assert_eq!(graph.search("COFFEE", &filter, 10).await.unwrap().len(), 1);
562    }
563
564    #[tokio::test]
565    async fn search_limit() {
566        let graph = FlatGraph::new();
567        let filter = test_filter("u");
568
569        let entities = vec![make_entity("A", "Other", None)];
570
571        for i in 0..10 {
572            let rels = vec![make_relation("A", &format!("rel_{i}"), &format!("B{i}"))];
573            graph
574                .add_relations(&rels, &entities, &filter)
575                .await
576                .unwrap();
577        }
578
579        let r = graph.search("A", &filter, 3).await.unwrap();
580        assert_eq!(r.len(), 3);
581    }
582
583    #[tokio::test]
584    async fn delete_all_by_user() {
585        let graph = FlatGraph::new();
586        let filter1 = test_filter("user1");
587        let filter2 = test_filter("user2");
588
589        let entities = vec![make_entity("X", "Other", None)];
590        let rels = vec![make_relation("X", "r", "Y")];
591
592        graph
593            .add_relations(&rels, &entities, &filter1)
594            .await
595            .unwrap();
596        graph
597            .add_relations(&rels, &entities, &filter2)
598            .await
599            .unwrap();
600
601        graph.delete_all(&filter1).await.unwrap();
602
603        let empty_filter = MemoryFilter::default();
604        let r = graph.search("X", &empty_filter, 10).await.unwrap();
605        assert_eq!(r.len(), 1);
606    }
607
608    #[tokio::test]
609    async fn delete_all_respects_agent_and_run_scope() {
610        let graph = FlatGraph::new();
611        let scoped_a = scoped_filter("user1", "agent-a", "run-a");
612        let scoped_b = scoped_filter("user1", "agent-b", "run-b");
613
614        let entities = vec![make_entity("Shared", "Other", None)];
615        let rels_a = vec![make_relation("Shared", "likes", "Rust")];
616        let rels_b = vec![make_relation("Shared", "likes", "Python")];
617
618        graph
619            .add_relations(&rels_a, &entities, &scoped_a)
620            .await
621            .unwrap();
622        graph
623            .add_relations(&rels_b, &entities, &scoped_b)
624            .await
625            .unwrap();
626
627        graph.delete_all(&scoped_a).await.unwrap();
628
629        let remaining_a = graph.search("Shared", &scoped_a, 10).await.unwrap();
630        let remaining_b = graph.search("Shared", &scoped_b, 10).await.unwrap();
631        assert!(remaining_a.is_empty());
632        assert_eq!(remaining_b.len(), 1);
633        assert_eq!(remaining_b[0].destination, "Python");
634    }
635
636    #[tokio::test]
637    async fn reset_clears_all() {
638        let graph = FlatGraph::new();
639        let filter = test_filter("u");
640
641        let entities = vec![make_entity("X", "Other", None)];
642        let rels = vec![make_relation("X", "r", "Y")];
643
644        graph
645            .add_relations(&rels, &entities, &filter)
646            .await
647            .unwrap();
648
649        graph.reset().await.unwrap();
650
651        let empty_filter = MemoryFilter::default();
652        assert!(
653            graph
654                .search("X", &empty_filter, 10)
655                .await
656                .unwrap()
657                .is_empty()
658        );
659        assert!(graph.entities.read().unwrap().is_empty());
660    }
661}