Skip to main content

mem_graph/
memory.rs

1//! In-memory graph store with KNN search over embeddings.
2
3use mem_types::{
4    GraphDirection, GraphNeighbor, GraphPath, GraphStore, GraphStoreError, MemoryEdge, MemoryNode,
5    VecSearchHit,
6};
7use std::collections::{HashMap, HashSet, VecDeque};
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
/// Secondary node index: owner (user name) -> scope -> ids of the nodes filed under that scope.
type ScopeIndex = HashMap<String, HashMap<String, Vec<String>>>;
/// Adjacency index: node id -> ids of the edges attached to that node.
type EdgeIndex = HashMap<String, Vec<String>>;
13
14/// In-memory implementation of GraphStore.
15/// Nodes are keyed by id (globally unique); user/scope indexed for get_all_memory_items and search filtering.
16pub struct InMemoryGraphStore {
17    /// node_id -> node (embedding optional; used for search_by_embedding when present).
18    nodes: Arc<RwLock<HashMap<String, MemoryNode>>>,
19    /// user_name -> scope -> node_ids (for get_all_memory_items).
20    scope_index: Arc<RwLock<ScopeIndex>>,
21    /// edge_id -> edge.
22    edges: Arc<RwLock<HashMap<String, MemoryEdge>>>,
23    /// from_node_id -> edge_ids.
24    out_index: Arc<RwLock<EdgeIndex>>,
25    /// to_node_id -> edge_ids.
26    in_index: Arc<RwLock<EdgeIndex>>,
27}
28
29impl InMemoryGraphStore {
30    pub fn new() -> Self {
31        Self {
32            nodes: Arc::new(RwLock::new(HashMap::new())),
33            scope_index: Arc::new(RwLock::new(HashMap::new())),
34            edges: Arc::new(RwLock::new(HashMap::new())),
35            out_index: Arc::new(RwLock::new(HashMap::new())),
36            in_index: Arc::new(RwLock::new(HashMap::new())),
37        }
38    }
39
40    fn scope_for_node(metadata: &HashMap<String, serde_json::Value>) -> String {
41        metadata
42            .get("scope")
43            .and_then(|v| v.as_str())
44            .unwrap_or("LongTermMemory")
45            .to_string()
46    }
47
48    fn owner_from_metadata(metadata: &HashMap<String, serde_json::Value>) -> &str {
49        metadata
50            .get("user_name")
51            .and_then(|v| v.as_str())
52            .unwrap_or("")
53    }
54
55    fn add_edge_to_index(index: &mut EdgeIndex, node_id: &str, edge_id: &str) {
56        let list = index.entry(node_id.to_string()).or_default();
57        if !list.contains(&edge_id.to_string()) {
58            list.push(edge_id.to_string());
59        }
60    }
61
62    fn remove_edge_from_index(index: &mut EdgeIndex, node_id: &str, edge_id: &str) {
63        if let Some(list) = index.get_mut(node_id) {
64            list.retain(|x| x != edge_id);
65            if list.is_empty() {
66                index.remove(node_id);
67            }
68        }
69    }
70
71    fn add_edge_indexes(edge: &MemoryEdge, out_index: &mut EdgeIndex, in_index: &mut EdgeIndex) {
72        Self::add_edge_to_index(out_index, &edge.from, &edge.id);
73        Self::add_edge_to_index(in_index, &edge.to, &edge.id);
74    }
75
76    fn remove_edge_indexes(
77        edge: &MemoryEdge,
78        out_index: &mut EdgeIndex,
79        in_index: &mut EdgeIndex,
80    ) {
81        Self::remove_edge_from_index(out_index, &edge.from, &edge.id);
82        Self::remove_edge_from_index(in_index, &edge.to, &edge.id);
83    }
84
85    fn strip_embedding(mut node: MemoryNode, include_embedding: bool) -> MemoryNode {
86        if !include_embedding {
87            node.embedding = None;
88        }
89        node
90    }
91
92    fn is_tombstone(metadata: &HashMap<String, serde_json::Value>) -> bool {
93        metadata
94            .get("state")
95            .and_then(|v| v.as_str())
96            .unwrap_or("active")
97            == "tombstone"
98    }
99}
100
101impl Default for InMemoryGraphStore {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107#[async_trait::async_trait]
108impl GraphStore for InMemoryGraphStore {
109    async fn add_node(
110        &self,
111        id: &str,
112        memory: &str,
113        metadata: &HashMap<String, serde_json::Value>,
114        user_name: Option<&str>,
115    ) -> Result<(), GraphStoreError> {
116        let un = user_name.unwrap_or("");
117        let scope = Self::scope_for_node(metadata);
118        let mut meta = metadata.clone();
119        meta.insert(
120            "user_name".to_string(),
121            serde_json::Value::String(un.to_string()),
122        );
123        let node = MemoryNode {
124            id: id.to_string(),
125            memory: memory.to_string(),
126            metadata: meta,
127            embedding: None,
128        };
129        {
130            let mut nodes = self.nodes.write().await;
131            nodes.insert(id.to_string(), node);
132        }
133        {
134            let mut idx = self.scope_index.write().await;
135            let user_map = idx.entry(un.to_string()).or_default();
136            let scope_list = user_map.entry(scope).or_default();
137            if !scope_list.contains(&id.to_string()) {
138                scope_list.push(id.to_string());
139            }
140        }
141        Ok(())
142    }
143
144    async fn add_nodes_batch(
145        &self,
146        nodes: &[MemoryNode],
147        user_name: Option<&str>,
148    ) -> Result<(), GraphStoreError> {
149        let un = user_name.unwrap_or("");
150        let mut guard = self.nodes.write().await;
151        let mut idx_guard = self.scope_index.write().await;
152        let user_map = idx_guard.entry(un.to_string()).or_default();
153        for node in nodes {
154            let scope = Self::scope_for_node(&node.metadata);
155            let mut n = node.clone();
156            n.metadata.insert(
157                "user_name".to_string(),
158                serde_json::Value::String(un.to_string()),
159            );
160            guard.insert(n.id.clone(), n);
161            let scope_list = user_map.entry(scope).or_default();
162            if !scope_list.contains(&node.id) {
163                scope_list.push(node.id.clone());
164            }
165        }
166        Ok(())
167    }
168
169    async fn add_edges_batch(
170        &self,
171        edges: &[MemoryEdge],
172        user_name: Option<&str>,
173    ) -> Result<(), GraphStoreError> {
174        if edges.is_empty() {
175            return Ok(());
176        }
177        {
178            let nodes = self.nodes.read().await;
179            for edge in edges {
180                let from_node = nodes.get(&edge.from).ok_or_else(|| {
181                    GraphStoreError::Other(format!("from node not found: {}", edge.from))
182                })?;
183                let to_node = nodes.get(&edge.to).ok_or_else(|| {
184                    GraphStoreError::Other(format!("to node not found: {}", edge.to))
185                })?;
186                if let Some(un) = user_name {
187                    if Self::owner_from_metadata(&from_node.metadata) != un
188                        || Self::owner_from_metadata(&to_node.metadata) != un
189                    {
190                        return Err(GraphStoreError::Other(format!(
191                            "node not found or access denied for edge: {}",
192                            edge.id
193                        )));
194                    }
195                }
196            }
197        }
198
199        let un = user_name.unwrap_or("");
200        let mut edge_guard = self.edges.write().await;
201        let mut out_guard = self.out_index.write().await;
202        let mut in_guard = self.in_index.write().await;
203        for edge in edges {
204            let mut normalized = edge.clone();
205            normalized.metadata.insert(
206                "user_name".to_string(),
207                serde_json::Value::String(un.to_string()),
208            );
209            if let Some(old) = edge_guard.insert(normalized.id.clone(), normalized.clone()) {
210                Self::remove_edge_indexes(&old, &mut out_guard, &mut in_guard);
211            }
212            Self::add_edge_indexes(&normalized, &mut out_guard, &mut in_guard);
213        }
214        Ok(())
215    }
216
217    async fn get_node(
218        &self,
219        id: &str,
220        include_embedding: bool,
221    ) -> Result<Option<MemoryNode>, GraphStoreError> {
222        let guard = self.nodes.read().await;
223        Ok(guard
224            .get(id)
225            .cloned()
226            .map(|n| Self::strip_embedding(n, include_embedding)))
227    }
228
229    async fn get_nodes(
230        &self,
231        ids: &[String],
232        include_embedding: bool,
233    ) -> Result<Vec<MemoryNode>, GraphStoreError> {
234        let guard = self.nodes.read().await;
235        let mut result = Vec::with_capacity(ids.len());
236        for id in ids {
237            if let Some(node) = guard.get(id) {
238                result.push(Self::strip_embedding(node.clone(), include_embedding));
239            }
240        }
241        Ok(result)
242    }
243
244    async fn get_neighbors(
245        &self,
246        id: &str,
247        relation: Option<&str>,
248        direction: GraphDirection,
249        limit: usize,
250        include_embedding: bool,
251        user_name: Option<&str>,
252    ) -> Result<Vec<GraphNeighbor>, GraphStoreError> {
253        if limit == 0 {
254            return Ok(Vec::new());
255        }
256
257        {
258            let nodes = self.nodes.read().await;
259            let node = nodes
260                .get(id)
261                .ok_or_else(|| GraphStoreError::Other(format!("node not found: {}", id)))?;
262            if let Some(un) = user_name {
263                if Self::owner_from_metadata(&node.metadata) != un {
264                    return Err(GraphStoreError::Other(format!(
265                        "node not found or access denied: {}",
266                        id
267                    )));
268                }
269            }
270        }
271
272        let mut edge_ids: Vec<String> = Vec::new();
273        match direction {
274            GraphDirection::Outbound => {
275                let out_guard = self.out_index.read().await;
276                edge_ids.extend(out_guard.get(id).cloned().unwrap_or_default());
277            }
278            GraphDirection::Inbound => {
279                let in_guard = self.in_index.read().await;
280                edge_ids.extend(in_guard.get(id).cloned().unwrap_or_default());
281            }
282            GraphDirection::Both => {
283                let out_guard = self.out_index.read().await;
284                edge_ids.extend(out_guard.get(id).cloned().unwrap_or_default());
285                let in_guard = self.in_index.read().await;
286                edge_ids.extend(in_guard.get(id).cloned().unwrap_or_default());
287            }
288        }
289        if edge_ids.is_empty() {
290            return Ok(Vec::new());
291        }
292
293        let edge_guard = self.edges.read().await;
294        let node_guard = self.nodes.read().await;
295        let mut visited = HashSet::new();
296        let mut edges_to_visit: Vec<MemoryEdge> = Vec::new();
297
298        for edge_id in edge_ids {
299            if !visited.insert(edge_id.clone()) {
300                continue;
301            }
302            let edge = match edge_guard.get(&edge_id) {
303                Some(e) => e.clone(),
304                None => continue,
305            };
306            if let Some(un) = user_name {
307                if Self::owner_from_metadata(&edge.metadata) != un {
308                    continue;
309                }
310            }
311            if let Some(rel) = relation {
312                if edge.relation != rel {
313                    continue;
314                }
315            }
316            edges_to_visit.push(edge);
317        }
318        // Keep traversal deterministic across runtimes and hash-map ordering.
319        edges_to_visit.sort_by(|a, b| a.id.cmp(&b.id));
320
321        let mut result = Vec::new();
322        for edge in edges_to_visit {
323            if result.len() >= limit {
324                break;
325            }
326            let neighbor_id = match direction {
327                GraphDirection::Outbound => {
328                    if edge.from == id {
329                        &edge.to
330                    } else {
331                        continue;
332                    }
333                }
334                GraphDirection::Inbound => {
335                    if edge.to == id {
336                        &edge.from
337                    } else {
338                        continue;
339                    }
340                }
341                GraphDirection::Both => {
342                    if edge.from == id {
343                        &edge.to
344                    } else if edge.to == id {
345                        &edge.from
346                    } else {
347                        continue;
348                    }
349                }
350            };
351
352            let neighbor_node = match node_guard.get(neighbor_id) {
353                Some(n) => n,
354                None => continue,
355            };
356            if let Some(un) = user_name {
357                if Self::owner_from_metadata(&neighbor_node.metadata) != un {
358                    continue;
359                }
360            }
361
362            result.push(GraphNeighbor {
363                edge,
364                node: Self::strip_embedding(neighbor_node.clone(), include_embedding),
365            });
366        }
367        Ok(result)
368    }
369
370    async fn shortest_path(
371        &self,
372        source_id: &str,
373        target_id: &str,
374        relation: Option<&str>,
375        direction: GraphDirection,
376        max_depth: usize,
377        include_deleted: bool,
378        user_name: Option<&str>,
379    ) -> Result<Option<GraphPath>, GraphStoreError> {
380        if max_depth == 0 && source_id != target_id {
381            return Ok(None);
382        }
383
384        {
385            let nodes = self.nodes.read().await;
386            let source = nodes.get(source_id).ok_or_else(|| {
387                GraphStoreError::Other(format!("node not found: {}", source_id))
388            })?;
389            let target = nodes.get(target_id).ok_or_else(|| {
390                GraphStoreError::Other(format!("node not found: {}", target_id))
391            })?;
392            if let Some(un) = user_name {
393                if Self::owner_from_metadata(&source.metadata) != un
394                    || Self::owner_from_metadata(&target.metadata) != un
395                {
396                    return Err(GraphStoreError::Other(format!(
397                        "node not found or access denied: {} -> {}",
398                        source_id, target_id
399                    )));
400                }
401            }
402            if !include_deleted
403                && (Self::is_tombstone(&source.metadata) || Self::is_tombstone(&target.metadata))
404            {
405                return Ok(None);
406            }
407        }
408
409        if source_id == target_id {
410            return Ok(Some(GraphPath {
411                node_ids: vec![source_id.to_string()],
412                edges: Vec::new(),
413            }));
414        }
415
416        let edge_guard = self.edges.read().await;
417        let node_guard = self.nodes.read().await;
418        let out_guard = self.out_index.read().await;
419        let in_guard = self.in_index.read().await;
420
421        let mut queue: VecDeque<(String, usize)> = VecDeque::new();
422        let mut visited: HashSet<String> = HashSet::new();
423        let mut prev: HashMap<String, (String, MemoryEdge)> = HashMap::new();
424
425        queue.push_back((source_id.to_string(), 0));
426        visited.insert(source_id.to_string());
427
428        while let Some((current, depth)) = queue.pop_front() {
429            if depth >= max_depth {
430                continue;
431            }
432
433            let mut transitions: Vec<(String, MemoryEdge)> = Vec::new();
434            let mut edge_ids: Vec<String> = Vec::new();
435            match direction {
436                GraphDirection::Outbound => {
437                    edge_ids.extend(out_guard.get(&current).cloned().unwrap_or_default());
438                }
439                GraphDirection::Inbound => {
440                    edge_ids.extend(in_guard.get(&current).cloned().unwrap_or_default());
441                }
442                GraphDirection::Both => {
443                    edge_ids.extend(out_guard.get(&current).cloned().unwrap_or_default());
444                    edge_ids.extend(in_guard.get(&current).cloned().unwrap_or_default());
445                }
446            }
447
448            let mut dedup = HashSet::new();
449            for edge_id in edge_ids {
450                if !dedup.insert(edge_id.clone()) {
451                    continue;
452                }
453                let edge = match edge_guard.get(&edge_id) {
454                    Some(e) => e.clone(),
455                    None => continue,
456                };
457                if let Some(un) = user_name {
458                    if Self::owner_from_metadata(&edge.metadata) != un {
459                        continue;
460                    }
461                }
462                if let Some(rel) = relation {
463                    if edge.relation != rel {
464                        continue;
465                    }
466                }
467
468                let next = match direction {
469                    GraphDirection::Outbound => {
470                        if edge.from == current {
471                            Some(edge.to.clone())
472                        } else {
473                            None
474                        }
475                    }
476                    GraphDirection::Inbound => {
477                        if edge.to == current {
478                            Some(edge.from.clone())
479                        } else {
480                            None
481                        }
482                    }
483                    GraphDirection::Both => {
484                        if edge.from == current {
485                            Some(edge.to.clone())
486                        } else if edge.to == current {
487                            Some(edge.from.clone())
488                        } else {
489                            None
490                        }
491                    }
492                };
493                let Some(next_node_id) = next else { continue };
494                let Some(next_node) = node_guard.get(&next_node_id) else {
495                    continue;
496                };
497                if let Some(un) = user_name {
498                    if Self::owner_from_metadata(&next_node.metadata) != un {
499                        continue;
500                    }
501                }
502                if !include_deleted && Self::is_tombstone(&next_node.metadata) {
503                    continue;
504                }
505                transitions.push((next_node_id, edge));
506            }
507
508            transitions.sort_by(|a, b| a.1.id.cmp(&b.1.id).then_with(|| a.0.cmp(&b.0)));
509
510            for (next_node_id, edge) in transitions {
511                if visited.contains(&next_node_id) {
512                    continue;
513                }
514                visited.insert(next_node_id.clone());
515                prev.insert(next_node_id.clone(), (current.clone(), edge));
516                if next_node_id == target_id {
517                    let mut rev_nodes = vec![target_id.to_string()];
518                    let mut rev_edges: Vec<MemoryEdge> = Vec::new();
519                    let mut cursor = target_id.to_string();
520                    while cursor != source_id {
521                        let (p, e) = prev.get(&cursor).ok_or_else(|| {
522                            GraphStoreError::Other("path reconstruction failed".to_string())
523                        })?;
524                        rev_edges.push(e.clone());
525                        rev_nodes.push(p.clone());
526                        cursor = p.clone();
527                    }
528                    rev_nodes.reverse();
529                    rev_edges.reverse();
530                    return Ok(Some(GraphPath {
531                        node_ids: rev_nodes,
532                        edges: rev_edges,
533                    }));
534                }
535                queue.push_back((next_node_id, depth + 1));
536            }
537        }
538
539        Ok(None)
540    }
541
542    async fn find_paths(
543        &self,
544        source_id: &str,
545        target_id: &str,
546        relation: Option<&str>,
547        direction: GraphDirection,
548        max_depth: usize,
549        top_k: usize,
550        include_deleted: bool,
551        user_name: Option<&str>,
552    ) -> Result<Vec<GraphPath>, GraphStoreError> {
553        if top_k == 0 {
554            return Ok(Vec::new());
555        }
556        if max_depth == 0 && source_id != target_id {
557            return Ok(Vec::new());
558        }
559
560        {
561            let nodes = self.nodes.read().await;
562            let source = nodes.get(source_id).ok_or_else(|| {
563                GraphStoreError::Other(format!("node not found: {}", source_id))
564            })?;
565            let target = nodes.get(target_id).ok_or_else(|| {
566                GraphStoreError::Other(format!("node not found: {}", target_id))
567            })?;
568            if let Some(un) = user_name {
569                if Self::owner_from_metadata(&source.metadata) != un
570                    || Self::owner_from_metadata(&target.metadata) != un
571                {
572                    return Err(GraphStoreError::Other(format!(
573                        "node not found or access denied: {} -> {}",
574                        source_id, target_id
575                    )));
576                }
577            }
578            if !include_deleted
579                && (Self::is_tombstone(&source.metadata) || Self::is_tombstone(&target.metadata))
580            {
581                return Ok(Vec::new());
582            }
583        }
584
585        if source_id == target_id {
586            return Ok(vec![GraphPath {
587                node_ids: vec![source_id.to_string()],
588                edges: Vec::new(),
589            }]);
590        }
591
592        #[derive(Clone)]
593        struct PathState {
594            current: String,
595            node_ids: Vec<String>,
596            edges: Vec<MemoryEdge>,
597            visited: HashSet<String>,
598        }
599
600        let edge_guard = self.edges.read().await;
601        let node_guard = self.nodes.read().await;
602        let out_guard = self.out_index.read().await;
603        let in_guard = self.in_index.read().await;
604
605        let mut queue: VecDeque<PathState> = VecDeque::new();
606        let mut start_visited = HashSet::new();
607        start_visited.insert(source_id.to_string());
608        queue.push_back(PathState {
609            current: source_id.to_string(),
610            node_ids: vec![source_id.to_string()],
611            edges: Vec::new(),
612            visited: start_visited,
613        });
614
615        let mut results: Vec<GraphPath> = Vec::new();
616        while let Some(state) = queue.pop_front() {
617            if results.len() >= top_k {
618                break;
619            }
620            if state.current == target_id {
621                results.push(GraphPath {
622                    node_ids: state.node_ids.clone(),
623                    edges: state.edges.clone(),
624                });
625                continue;
626            }
627            if state.edges.len() >= max_depth {
628                continue;
629            }
630
631            let mut edge_ids: Vec<String> = Vec::new();
632            match direction {
633                GraphDirection::Outbound => {
634                    edge_ids.extend(out_guard.get(&state.current).cloned().unwrap_or_default());
635                }
636                GraphDirection::Inbound => {
637                    edge_ids.extend(in_guard.get(&state.current).cloned().unwrap_or_default());
638                }
639                GraphDirection::Both => {
640                    edge_ids.extend(out_guard.get(&state.current).cloned().unwrap_or_default());
641                    edge_ids.extend(in_guard.get(&state.current).cloned().unwrap_or_default());
642                }
643            }
644
645            let mut dedup = HashSet::new();
646            let mut transitions: Vec<(String, MemoryEdge)> = Vec::new();
647            for edge_id in edge_ids {
648                if !dedup.insert(edge_id.clone()) {
649                    continue;
650                }
651                let edge = match edge_guard.get(&edge_id) {
652                    Some(e) => e.clone(),
653                    None => continue,
654                };
655                if let Some(un) = user_name {
656                    if Self::owner_from_metadata(&edge.metadata) != un {
657                        continue;
658                    }
659                }
660                if let Some(rel) = relation {
661                    if edge.relation != rel {
662                        continue;
663                    }
664                }
665
666                let next = match direction {
667                    GraphDirection::Outbound => {
668                        if edge.from == state.current {
669                            Some(edge.to.clone())
670                        } else {
671                            None
672                        }
673                    }
674                    GraphDirection::Inbound => {
675                        if edge.to == state.current {
676                            Some(edge.from.clone())
677                        } else {
678                            None
679                        }
680                    }
681                    GraphDirection::Both => {
682                        if edge.from == state.current {
683                            Some(edge.to.clone())
684                        } else if edge.to == state.current {
685                            Some(edge.from.clone())
686                        } else {
687                            None
688                        }
689                    }
690                };
691
692                let Some(next_node_id) = next else { continue };
693                if state.visited.contains(&next_node_id) {
694                    continue;
695                }
696                let Some(next_node) = node_guard.get(&next_node_id) else {
697                    continue;
698                };
699                if let Some(un) = user_name {
700                    if Self::owner_from_metadata(&next_node.metadata) != un {
701                        continue;
702                    }
703                }
704                if !include_deleted && Self::is_tombstone(&next_node.metadata) {
705                    continue;
706                }
707                transitions.push((next_node_id, edge));
708            }
709            transitions.sort_by(|a, b| a.1.id.cmp(&b.1.id).then_with(|| a.0.cmp(&b.0)));
710
711            for (next_node_id, edge) in transitions {
712                let mut next_state = state.clone();
713                next_state.current = next_node_id.clone();
714                next_state.node_ids.push(next_node_id.clone());
715                next_state.edges.push(edge);
716                next_state.visited.insert(next_node_id);
717                queue.push_back(next_state);
718            }
719        }
720
721        Ok(results)
722    }
723
724    async fn search_by_embedding(
725        &self,
726        vector: &[f32],
727        top_k: usize,
728        user_name: Option<&str>,
729    ) -> Result<Vec<VecSearchHit>, GraphStoreError> {
730        let guard = self.nodes.read().await;
731        let un = user_name.unwrap_or("");
732        let mut candidates: Vec<(String, f64)> = Vec::new();
733        for node in guard.values() {
734            if !un.is_empty() {
735                let node_user = Self::owner_from_metadata(&node.metadata);
736                if node_user != un {
737                    continue;
738                }
739            }
740            let emb = match &node.embedding {
741                Some(e) => e,
742                None => continue,
743            };
744            if emb.len() != vector.len() {
745                continue;
746            }
747            let score = cosine_similarity(vector, emb);
748            candidates.push((node.id.clone(), score));
749        }
750        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
751        let hits = candidates
752            .into_iter()
753            .take(top_k)
754            .map(|(id, score)| VecSearchHit { id, score })
755            .collect();
756        Ok(hits)
757    }
758
759    async fn get_all_memory_items(
760        &self,
761        scope: &str,
762        user_name: &str,
763        include_embedding: bool,
764    ) -> Result<Vec<MemoryNode>, GraphStoreError> {
765        let ids = {
766            let idx = self.scope_index.read().await;
767            idx.get(user_name)
768                .and_then(|m| m.get(scope))
769                .cloned()
770                .unwrap_or_default()
771        };
772        let mut nodes = self.get_nodes(&ids, include_embedding).await?;
773        nodes.sort_by(|a, b| a.id.cmp(&b.id));
774        Ok(nodes)
775    }
776
777    async fn update_node(
778        &self,
779        id: &str,
780        fields: &HashMap<String, serde_json::Value>,
781        user_name: Option<&str>,
782    ) -> Result<(), GraphStoreError> {
783        let mut guard = self.nodes.write().await;
784        let node = guard
785            .get_mut(id)
786            .ok_or_else(|| GraphStoreError::Other(format!("node not found: {}", id)))?;
787        if let Some(un) = user_name {
788            let node_owner = Self::owner_from_metadata(&node.metadata);
789            if node_owner != un {
790                return Err(GraphStoreError::Other(format!(
791                    "node not found or access denied: {}",
792                    id
793                )));
794            }
795        }
796        for (k, v) in fields {
797            if k == "memory" {
798                node.memory = v.as_str().unwrap_or("").to_string();
799            } else {
800                node.metadata.insert(k.clone(), v.clone());
801            }
802        }
803        Ok(())
804    }
805
806    async fn delete_node(&self, id: &str, user_name: Option<&str>) -> Result<(), GraphStoreError> {
807        {
808            let nodes = self.nodes.read().await;
809            let node = nodes
810                .get(id)
811                .ok_or_else(|| GraphStoreError::Other(format!("node not found: {}", id)))?;
812            if let Some(un) = user_name {
813                let node_owner = Self::owner_from_metadata(&node.metadata);
814                if node_owner != un {
815                    return Err(GraphStoreError::Other(format!(
816                        "node not found or access denied: {}",
817                        id
818                    )));
819                }
820            }
821        }
822        {
823            let mut nodes = self.nodes.write().await;
824            nodes
825                .remove(id)
826                .ok_or_else(|| GraphStoreError::Other(format!("node not found: {}", id)))?;
827        }
828        {
829            let mut idx = self.scope_index.write().await;
830            for scope_map in idx.values_mut() {
831                for list in scope_map.values_mut() {
832                    list.retain(|x| x != id);
833                }
834            }
835        }
836        self.delete_edges_by_node(id, user_name).await?;
837        Ok(())
838    }
839
840    async fn delete_edges_by_node(
841        &self,
842        id: &str,
843        user_name: Option<&str>,
844    ) -> Result<usize, GraphStoreError> {
845        let mut edge_ids: HashSet<String> = HashSet::new();
846        {
847            let out_guard = self.out_index.read().await;
848            edge_ids.extend(out_guard.get(id).cloned().unwrap_or_default());
849        }
850        {
851            let in_guard = self.in_index.read().await;
852            edge_ids.extend(in_guard.get(id).cloned().unwrap_or_default());
853        }
854
855        if edge_ids.is_empty() {
856            return Ok(0);
857        }
858
859        let mut edge_guard = self.edges.write().await;
860        let mut out_guard = self.out_index.write().await;
861        let mut in_guard = self.in_index.write().await;
862
863        if let Some(un) = user_name {
864            for edge_id in &edge_ids {
865                if let Some(edge) = edge_guard.get(edge_id) {
866                    if Self::owner_from_metadata(&edge.metadata) != un {
867                        return Err(GraphStoreError::Other(format!(
868                            "edge not found or access denied: {}",
869                            edge_id
870                        )));
871                    }
872                }
873            }
874        }
875
876        let mut deleted = 0usize;
877        for edge_id in edge_ids {
878            if let Some(edge) = edge_guard.remove(&edge_id) {
879                Self::remove_edge_indexes(&edge, &mut out_guard, &mut in_guard);
880                deleted += 1;
881            }
882        }
883        Ok(deleted)
884    }
885}
886
/// Cosine similarity of two embedding vectors, accumulated in `f64`.
///
/// Returns `0.0` when the slices differ in length, are empty, or when either
/// vector has zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    // Euclidean length of a vector, widened to f64 before squaring.
    let magnitude =
        |v: &[f32]| -> f64 { v.iter().map(|&c| f64::from(c).powi(2)).sum::<f64>().sqrt() };

    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let inner: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let len_a = magnitude(a);
    let len_b = magnitude(b);

    // A zero-length vector has no direction; report no similarity.
    if len_a == 0.0 || len_b == 0.0 {
        0.0
    } else {
        inner / (len_a * len_b)
    }
}
903
#[cfg(test)]
mod tests {
    use super::*;

    /// User name passed to every store call in these tests.
    const OWNER: &str = "u1";

    /// Node metadata carrying only the `LongTermMemory` scope tag.
    fn long_term_meta() -> HashMap<String, serde_json::Value> {
        HashMap::from([(
            "scope".to_string(),
            serde_json::Value::String("LongTermMemory".to_string()),
        )])
    }

    /// Builds an edge `from -> to` with the given relation and empty metadata.
    fn edge(id: &str, from: &str, to: &str, relation: &str) -> MemoryEdge {
        MemoryEdge {
            id: id.to_string(),
            from: from.to_string(),
            to: to.to_string(),
            relation: relation.to_string(),
            metadata: HashMap::new(),
        }
    }

    /// Neighbors come back in a stable order (e1 before e2) regardless of
    /// insertion order, and the limit argument truncates that same ordering.
    #[tokio::test]
    async fn neighbors_are_deterministic_and_limited() {
        let store = InMemoryGraphStore::new();
        let meta = long_term_meta();

        for (id, memory) in [("n0", "root"), ("n1", "node1"), ("n2", "node2")] {
            store.add_node(id, memory, &meta, Some(OWNER)).await.unwrap();
        }

        // Inserted out of id order on purpose: e2 first, then e1.
        let batch = [
            edge("e2", "n0", "n2", "related_to"),
            edge("e1", "n0", "n1", "related_to"),
        ];
        store.add_edges_batch(&batch, Some(OWNER)).await.unwrap();

        let all = store
            .get_neighbors(
                "n0",
                Some("related_to"),
                GraphDirection::Outbound,
                10,
                false,
                Some(OWNER),
            )
            .await
            .unwrap();
        let all_ids: Vec<&str> = all.iter().map(|n| n.edge.id.as_str()).collect();
        assert_eq!(all_ids, ["e1", "e2"]);

        let limited = store
            .get_neighbors(
                "n0",
                Some("related_to"),
                GraphDirection::Outbound,
                1,
                false,
                Some(OWNER),
            )
            .await
            .unwrap();
        let limited_ids: Vec<&str> = limited.iter().map(|n| n.edge.id.as_str()).collect();
        assert_eq!(limited_ids, ["e1"]);
    }

    /// Both routes from `a` to `c` (via `b` or via `d`) are two hops long, so the
    /// returned path must start at `a`, end at `c`, and contain two edges.
    #[tokio::test]
    async fn shortest_path_finds_min_hops() {
        let store = InMemoryGraphStore::new();
        let meta = long_term_meta();

        for (id, memory) in [("a", "A"), ("b", "B"), ("c", "C"), ("d", "D")] {
            store.add_node(id, memory, &meta, Some(OWNER)).await.unwrap();
        }

        let batch = [
            edge("e_ab", "a", "b", "related_to"),
            edge("e_bc", "b", "c", "related_to"),
            edge("e_ad", "a", "d", "related_to"),
            edge("e_dc", "d", "c", "related_to"),
        ];
        store.add_edges_batch(&batch, Some(OWNER)).await.unwrap();

        let found = store
            .shortest_path(
                "a",
                "c",
                Some("related_to"),
                GraphDirection::Outbound,
                3,
                false,
                Some(OWNER),
            )
            .await
            .unwrap();
        let path = found.unwrap();

        assert_eq!(path.node_ids.first().map(String::as_str), Some("a"));
        assert_eq!(path.node_ids.last().map(String::as_str), Some("c"));
        assert_eq!(path.edges.len(), 2);
    }

    /// A diamond `s -> {a, b} -> t` has exactly two two-hop paths; asking for the
    /// top two must return both.
    #[tokio::test]
    async fn find_paths_returns_top_k_shortest() {
        let store = InMemoryGraphStore::new();
        let meta = long_term_meta();

        for id in ["s", "a", "b", "t"] {
            store.add_node(id, id, &meta, Some(OWNER)).await.unwrap();
        }

        let batch = [
            edge("e_sa", "s", "a", "r"),
            edge("e_at", "a", "t", "r"),
            edge("e_sb", "s", "b", "r"),
            edge("e_bt", "b", "t", "r"),
        ];
        store.add_edges_batch(&batch, Some(OWNER)).await.unwrap();

        let paths = store
            .find_paths(
                "s",
                "t",
                Some("r"),
                GraphDirection::Outbound,
                3,
                2,
                false,
                Some(OWNER),
            )
            .await
            .unwrap();

        assert_eq!(paths.len(), 2);
        for path in &paths {
            assert_eq!(path.edges.len(), 2);
        }
    }
}