1use mem_types::{
4 GraphDirection, GraphNeighbor, GraphPath, GraphStore, GraphStoreError, MemoryEdge, MemoryNode,
5 VecSearchHit,
6};
7use std::collections::{HashMap, HashSet, VecDeque};
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
/// Per-owner buckets: scope name -> ids of the nodes filed under that scope.
type ScopeBuckets = HashMap<String, Vec<String>>;
/// Scope listing index: owner (`user_name`) -> scope name -> node ids.
type ScopeIndex = HashMap<String, ScopeBuckets>;
/// Adjacency index: node id -> ids of the edges attached to it on one side.
type EdgeIndex = HashMap<String, Vec<String>>;
13
14pub struct InMemoryGraphStore {
17 nodes: Arc<RwLock<HashMap<String, MemoryNode>>>,
19 scope_index: Arc<RwLock<ScopeIndex>>,
21 edges: Arc<RwLock<HashMap<String, MemoryEdge>>>,
23 out_index: Arc<RwLock<EdgeIndex>>,
25 in_index: Arc<RwLock<EdgeIndex>>,
27}
28
29impl InMemoryGraphStore {
30 pub fn new() -> Self {
31 Self {
32 nodes: Arc::new(RwLock::new(HashMap::new())),
33 scope_index: Arc::new(RwLock::new(HashMap::new())),
34 edges: Arc::new(RwLock::new(HashMap::new())),
35 out_index: Arc::new(RwLock::new(HashMap::new())),
36 in_index: Arc::new(RwLock::new(HashMap::new())),
37 }
38 }
39
40 fn scope_for_node(metadata: &HashMap<String, serde_json::Value>) -> String {
41 metadata
42 .get("scope")
43 .and_then(|v| v.as_str())
44 .unwrap_or("LongTermMemory")
45 .to_string()
46 }
47
48 fn owner_from_metadata(metadata: &HashMap<String, serde_json::Value>) -> &str {
49 metadata
50 .get("user_name")
51 .and_then(|v| v.as_str())
52 .unwrap_or("")
53 }
54
55 fn add_edge_to_index(index: &mut EdgeIndex, node_id: &str, edge_id: &str) {
56 let list = index.entry(node_id.to_string()).or_default();
57 if !list.contains(&edge_id.to_string()) {
58 list.push(edge_id.to_string());
59 }
60 }
61
62 fn remove_edge_from_index(index: &mut EdgeIndex, node_id: &str, edge_id: &str) {
63 if let Some(list) = index.get_mut(node_id) {
64 list.retain(|x| x != edge_id);
65 if list.is_empty() {
66 index.remove(node_id);
67 }
68 }
69 }
70
71 fn add_edge_indexes(edge: &MemoryEdge, out_index: &mut EdgeIndex, in_index: &mut EdgeIndex) {
72 Self::add_edge_to_index(out_index, &edge.from, &edge.id);
73 Self::add_edge_to_index(in_index, &edge.to, &edge.id);
74 }
75
76 fn remove_edge_indexes(
77 edge: &MemoryEdge,
78 out_index: &mut EdgeIndex,
79 in_index: &mut EdgeIndex,
80 ) {
81 Self::remove_edge_from_index(out_index, &edge.from, &edge.id);
82 Self::remove_edge_from_index(in_index, &edge.to, &edge.id);
83 }
84
85 fn strip_embedding(mut node: MemoryNode, include_embedding: bool) -> MemoryNode {
86 if !include_embedding {
87 node.embedding = None;
88 }
89 node
90 }
91
92 fn is_tombstone(metadata: &HashMap<String, serde_json::Value>) -> bool {
93 metadata
94 .get("state")
95 .and_then(|v| v.as_str())
96 .unwrap_or("active")
97 == "tombstone"
98 }
99}
100
101impl Default for InMemoryGraphStore {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
107#[async_trait::async_trait]
108impl GraphStore for InMemoryGraphStore {
109 async fn add_node(
110 &self,
111 id: &str,
112 memory: &str,
113 metadata: &HashMap<String, serde_json::Value>,
114 user_name: Option<&str>,
115 ) -> Result<(), GraphStoreError> {
116 let un = user_name.unwrap_or("");
117 let scope = Self::scope_for_node(metadata);
118 let mut meta = metadata.clone();
119 meta.insert(
120 "user_name".to_string(),
121 serde_json::Value::String(un.to_string()),
122 );
123 let node = MemoryNode {
124 id: id.to_string(),
125 memory: memory.to_string(),
126 metadata: meta,
127 embedding: None,
128 };
129 {
130 let mut nodes = self.nodes.write().await;
131 nodes.insert(id.to_string(), node);
132 }
133 {
134 let mut idx = self.scope_index.write().await;
135 let user_map = idx.entry(un.to_string()).or_default();
136 let scope_list = user_map.entry(scope).or_default();
137 if !scope_list.contains(&id.to_string()) {
138 scope_list.push(id.to_string());
139 }
140 }
141 Ok(())
142 }
143
144 async fn add_nodes_batch(
145 &self,
146 nodes: &[MemoryNode],
147 user_name: Option<&str>,
148 ) -> Result<(), GraphStoreError> {
149 let un = user_name.unwrap_or("");
150 let mut guard = self.nodes.write().await;
151 let mut idx_guard = self.scope_index.write().await;
152 let user_map = idx_guard.entry(un.to_string()).or_default();
153 for node in nodes {
154 let scope = Self::scope_for_node(&node.metadata);
155 let mut n = node.clone();
156 n.metadata.insert(
157 "user_name".to_string(),
158 serde_json::Value::String(un.to_string()),
159 );
160 guard.insert(n.id.clone(), n);
161 let scope_list = user_map.entry(scope).or_default();
162 if !scope_list.contains(&node.id) {
163 scope_list.push(node.id.clone());
164 }
165 }
166 Ok(())
167 }
168
169 async fn add_edges_batch(
170 &self,
171 edges: &[MemoryEdge],
172 user_name: Option<&str>,
173 ) -> Result<(), GraphStoreError> {
174 if edges.is_empty() {
175 return Ok(());
176 }
177 {
178 let nodes = self.nodes.read().await;
179 for edge in edges {
180 let from_node = nodes.get(&edge.from).ok_or_else(|| {
181 GraphStoreError::Other(format!("from node not found: {}", edge.from))
182 })?;
183 let to_node = nodes.get(&edge.to).ok_or_else(|| {
184 GraphStoreError::Other(format!("to node not found: {}", edge.to))
185 })?;
186 if let Some(un) = user_name {
187 if Self::owner_from_metadata(&from_node.metadata) != un
188 || Self::owner_from_metadata(&to_node.metadata) != un
189 {
190 return Err(GraphStoreError::Other(format!(
191 "node not found or access denied for edge: {}",
192 edge.id
193 )));
194 }
195 }
196 }
197 }
198
199 let un = user_name.unwrap_or("");
200 let mut edge_guard = self.edges.write().await;
201 let mut out_guard = self.out_index.write().await;
202 let mut in_guard = self.in_index.write().await;
203 for edge in edges {
204 let mut normalized = edge.clone();
205 normalized.metadata.insert(
206 "user_name".to_string(),
207 serde_json::Value::String(un.to_string()),
208 );
209 if let Some(old) = edge_guard.insert(normalized.id.clone(), normalized.clone()) {
210 Self::remove_edge_indexes(&old, &mut out_guard, &mut in_guard);
211 }
212 Self::add_edge_indexes(&normalized, &mut out_guard, &mut in_guard);
213 }
214 Ok(())
215 }
216
217 async fn get_node(
218 &self,
219 id: &str,
220 include_embedding: bool,
221 ) -> Result<Option<MemoryNode>, GraphStoreError> {
222 let guard = self.nodes.read().await;
223 Ok(guard
224 .get(id)
225 .cloned()
226 .map(|n| Self::strip_embedding(n, include_embedding)))
227 }
228
229 async fn get_nodes(
230 &self,
231 ids: &[String],
232 include_embedding: bool,
233 ) -> Result<Vec<MemoryNode>, GraphStoreError> {
234 let guard = self.nodes.read().await;
235 let mut result = Vec::with_capacity(ids.len());
236 for id in ids {
237 if let Some(node) = guard.get(id) {
238 result.push(Self::strip_embedding(node.clone(), include_embedding));
239 }
240 }
241 Ok(result)
242 }
243
244 async fn get_neighbors(
245 &self,
246 id: &str,
247 relation: Option<&str>,
248 direction: GraphDirection,
249 limit: usize,
250 include_embedding: bool,
251 user_name: Option<&str>,
252 ) -> Result<Vec<GraphNeighbor>, GraphStoreError> {
253 if limit == 0 {
254 return Ok(Vec::new());
255 }
256
257 {
258 let nodes = self.nodes.read().await;
259 let node = nodes
260 .get(id)
261 .ok_or_else(|| GraphStoreError::Other(format!("node not found: {}", id)))?;
262 if let Some(un) = user_name {
263 if Self::owner_from_metadata(&node.metadata) != un {
264 return Err(GraphStoreError::Other(format!(
265 "node not found or access denied: {}",
266 id
267 )));
268 }
269 }
270 }
271
272 let mut edge_ids: Vec<String> = Vec::new();
273 match direction {
274 GraphDirection::Outbound => {
275 let out_guard = self.out_index.read().await;
276 edge_ids.extend(out_guard.get(id).cloned().unwrap_or_default());
277 }
278 GraphDirection::Inbound => {
279 let in_guard = self.in_index.read().await;
280 edge_ids.extend(in_guard.get(id).cloned().unwrap_or_default());
281 }
282 GraphDirection::Both => {
283 let out_guard = self.out_index.read().await;
284 edge_ids.extend(out_guard.get(id).cloned().unwrap_or_default());
285 let in_guard = self.in_index.read().await;
286 edge_ids.extend(in_guard.get(id).cloned().unwrap_or_default());
287 }
288 }
289 if edge_ids.is_empty() {
290 return Ok(Vec::new());
291 }
292
293 let edge_guard = self.edges.read().await;
294 let node_guard = self.nodes.read().await;
295 let mut visited = HashSet::new();
296 let mut edges_to_visit: Vec<MemoryEdge> = Vec::new();
297
298 for edge_id in edge_ids {
299 if !visited.insert(edge_id.clone()) {
300 continue;
301 }
302 let edge = match edge_guard.get(&edge_id) {
303 Some(e) => e.clone(),
304 None => continue,
305 };
306 if let Some(un) = user_name {
307 if Self::owner_from_metadata(&edge.metadata) != un {
308 continue;
309 }
310 }
311 if let Some(rel) = relation {
312 if edge.relation != rel {
313 continue;
314 }
315 }
316 edges_to_visit.push(edge);
317 }
318 edges_to_visit.sort_by(|a, b| a.id.cmp(&b.id));
320
321 let mut result = Vec::new();
322 for edge in edges_to_visit {
323 if result.len() >= limit {
324 break;
325 }
326 let neighbor_id = match direction {
327 GraphDirection::Outbound => {
328 if edge.from == id {
329 &edge.to
330 } else {
331 continue;
332 }
333 }
334 GraphDirection::Inbound => {
335 if edge.to == id {
336 &edge.from
337 } else {
338 continue;
339 }
340 }
341 GraphDirection::Both => {
342 if edge.from == id {
343 &edge.to
344 } else if edge.to == id {
345 &edge.from
346 } else {
347 continue;
348 }
349 }
350 };
351
352 let neighbor_node = match node_guard.get(neighbor_id) {
353 Some(n) => n,
354 None => continue,
355 };
356 if let Some(un) = user_name {
357 if Self::owner_from_metadata(&neighbor_node.metadata) != un {
358 continue;
359 }
360 }
361
362 result.push(GraphNeighbor {
363 edge,
364 node: Self::strip_embedding(neighbor_node.clone(), include_embedding),
365 });
366 }
367 Ok(result)
368 }
369
370 async fn shortest_path(
371 &self,
372 source_id: &str,
373 target_id: &str,
374 relation: Option<&str>,
375 direction: GraphDirection,
376 max_depth: usize,
377 include_deleted: bool,
378 user_name: Option<&str>,
379 ) -> Result<Option<GraphPath>, GraphStoreError> {
380 if max_depth == 0 && source_id != target_id {
381 return Ok(None);
382 }
383
384 {
385 let nodes = self.nodes.read().await;
386 let source = nodes.get(source_id).ok_or_else(|| {
387 GraphStoreError::Other(format!("node not found: {}", source_id))
388 })?;
389 let target = nodes.get(target_id).ok_or_else(|| {
390 GraphStoreError::Other(format!("node not found: {}", target_id))
391 })?;
392 if let Some(un) = user_name {
393 if Self::owner_from_metadata(&source.metadata) != un
394 || Self::owner_from_metadata(&target.metadata) != un
395 {
396 return Err(GraphStoreError::Other(format!(
397 "node not found or access denied: {} -> {}",
398 source_id, target_id
399 )));
400 }
401 }
402 if !include_deleted
403 && (Self::is_tombstone(&source.metadata) || Self::is_tombstone(&target.metadata))
404 {
405 return Ok(None);
406 }
407 }
408
409 if source_id == target_id {
410 return Ok(Some(GraphPath {
411 node_ids: vec![source_id.to_string()],
412 edges: Vec::new(),
413 }));
414 }
415
416 let edge_guard = self.edges.read().await;
417 let node_guard = self.nodes.read().await;
418 let out_guard = self.out_index.read().await;
419 let in_guard = self.in_index.read().await;
420
421 let mut queue: VecDeque<(String, usize)> = VecDeque::new();
422 let mut visited: HashSet<String> = HashSet::new();
423 let mut prev: HashMap<String, (String, MemoryEdge)> = HashMap::new();
424
425 queue.push_back((source_id.to_string(), 0));
426 visited.insert(source_id.to_string());
427
428 while let Some((current, depth)) = queue.pop_front() {
429 if depth >= max_depth {
430 continue;
431 }
432
433 let mut transitions: Vec<(String, MemoryEdge)> = Vec::new();
434 let mut edge_ids: Vec<String> = Vec::new();
435 match direction {
436 GraphDirection::Outbound => {
437 edge_ids.extend(out_guard.get(¤t).cloned().unwrap_or_default());
438 }
439 GraphDirection::Inbound => {
440 edge_ids.extend(in_guard.get(¤t).cloned().unwrap_or_default());
441 }
442 GraphDirection::Both => {
443 edge_ids.extend(out_guard.get(¤t).cloned().unwrap_or_default());
444 edge_ids.extend(in_guard.get(¤t).cloned().unwrap_or_default());
445 }
446 }
447
448 let mut dedup = HashSet::new();
449 for edge_id in edge_ids {
450 if !dedup.insert(edge_id.clone()) {
451 continue;
452 }
453 let edge = match edge_guard.get(&edge_id) {
454 Some(e) => e.clone(),
455 None => continue,
456 };
457 if let Some(un) = user_name {
458 if Self::owner_from_metadata(&edge.metadata) != un {
459 continue;
460 }
461 }
462 if let Some(rel) = relation {
463 if edge.relation != rel {
464 continue;
465 }
466 }
467
468 let next = match direction {
469 GraphDirection::Outbound => {
470 if edge.from == current {
471 Some(edge.to.clone())
472 } else {
473 None
474 }
475 }
476 GraphDirection::Inbound => {
477 if edge.to == current {
478 Some(edge.from.clone())
479 } else {
480 None
481 }
482 }
483 GraphDirection::Both => {
484 if edge.from == current {
485 Some(edge.to.clone())
486 } else if edge.to == current {
487 Some(edge.from.clone())
488 } else {
489 None
490 }
491 }
492 };
493 let Some(next_node_id) = next else { continue };
494 let Some(next_node) = node_guard.get(&next_node_id) else {
495 continue;
496 };
497 if let Some(un) = user_name {
498 if Self::owner_from_metadata(&next_node.metadata) != un {
499 continue;
500 }
501 }
502 if !include_deleted && Self::is_tombstone(&next_node.metadata) {
503 continue;
504 }
505 transitions.push((next_node_id, edge));
506 }
507
508 transitions.sort_by(|a, b| a.1.id.cmp(&b.1.id).then_with(|| a.0.cmp(&b.0)));
509
510 for (next_node_id, edge) in transitions {
511 if visited.contains(&next_node_id) {
512 continue;
513 }
514 visited.insert(next_node_id.clone());
515 prev.insert(next_node_id.clone(), (current.clone(), edge));
516 if next_node_id == target_id {
517 let mut rev_nodes = vec![target_id.to_string()];
518 let mut rev_edges: Vec<MemoryEdge> = Vec::new();
519 let mut cursor = target_id.to_string();
520 while cursor != source_id {
521 let (p, e) = prev.get(&cursor).ok_or_else(|| {
522 GraphStoreError::Other("path reconstruction failed".to_string())
523 })?;
524 rev_edges.push(e.clone());
525 rev_nodes.push(p.clone());
526 cursor = p.clone();
527 }
528 rev_nodes.reverse();
529 rev_edges.reverse();
530 return Ok(Some(GraphPath {
531 node_ids: rev_nodes,
532 edges: rev_edges,
533 }));
534 }
535 queue.push_back((next_node_id, depth + 1));
536 }
537 }
538
539 Ok(None)
540 }
541
542 async fn find_paths(
543 &self,
544 source_id: &str,
545 target_id: &str,
546 relation: Option<&str>,
547 direction: GraphDirection,
548 max_depth: usize,
549 top_k: usize,
550 include_deleted: bool,
551 user_name: Option<&str>,
552 ) -> Result<Vec<GraphPath>, GraphStoreError> {
553 if top_k == 0 {
554 return Ok(Vec::new());
555 }
556 if max_depth == 0 && source_id != target_id {
557 return Ok(Vec::new());
558 }
559
560 {
561 let nodes = self.nodes.read().await;
562 let source = nodes.get(source_id).ok_or_else(|| {
563 GraphStoreError::Other(format!("node not found: {}", source_id))
564 })?;
565 let target = nodes.get(target_id).ok_or_else(|| {
566 GraphStoreError::Other(format!("node not found: {}", target_id))
567 })?;
568 if let Some(un) = user_name {
569 if Self::owner_from_metadata(&source.metadata) != un
570 || Self::owner_from_metadata(&target.metadata) != un
571 {
572 return Err(GraphStoreError::Other(format!(
573 "node not found or access denied: {} -> {}",
574 source_id, target_id
575 )));
576 }
577 }
578 if !include_deleted
579 && (Self::is_tombstone(&source.metadata) || Self::is_tombstone(&target.metadata))
580 {
581 return Ok(Vec::new());
582 }
583 }
584
585 if source_id == target_id {
586 return Ok(vec![GraphPath {
587 node_ids: vec![source_id.to_string()],
588 edges: Vec::new(),
589 }]);
590 }
591
592 #[derive(Clone)]
593 struct PathState {
594 current: String,
595 node_ids: Vec<String>,
596 edges: Vec<MemoryEdge>,
597 visited: HashSet<String>,
598 }
599
600 let edge_guard = self.edges.read().await;
601 let node_guard = self.nodes.read().await;
602 let out_guard = self.out_index.read().await;
603 let in_guard = self.in_index.read().await;
604
605 let mut queue: VecDeque<PathState> = VecDeque::new();
606 let mut start_visited = HashSet::new();
607 start_visited.insert(source_id.to_string());
608 queue.push_back(PathState {
609 current: source_id.to_string(),
610 node_ids: vec![source_id.to_string()],
611 edges: Vec::new(),
612 visited: start_visited,
613 });
614
615 let mut results: Vec<GraphPath> = Vec::new();
616 while let Some(state) = queue.pop_front() {
617 if results.len() >= top_k {
618 break;
619 }
620 if state.current == target_id {
621 results.push(GraphPath {
622 node_ids: state.node_ids.clone(),
623 edges: state.edges.clone(),
624 });
625 continue;
626 }
627 if state.edges.len() >= max_depth {
628 continue;
629 }
630
631 let mut edge_ids: Vec<String> = Vec::new();
632 match direction {
633 GraphDirection::Outbound => {
634 edge_ids.extend(out_guard.get(&state.current).cloned().unwrap_or_default());
635 }
636 GraphDirection::Inbound => {
637 edge_ids.extend(in_guard.get(&state.current).cloned().unwrap_or_default());
638 }
639 GraphDirection::Both => {
640 edge_ids.extend(out_guard.get(&state.current).cloned().unwrap_or_default());
641 edge_ids.extend(in_guard.get(&state.current).cloned().unwrap_or_default());
642 }
643 }
644
645 let mut dedup = HashSet::new();
646 let mut transitions: Vec<(String, MemoryEdge)> = Vec::new();
647 for edge_id in edge_ids {
648 if !dedup.insert(edge_id.clone()) {
649 continue;
650 }
651 let edge = match edge_guard.get(&edge_id) {
652 Some(e) => e.clone(),
653 None => continue,
654 };
655 if let Some(un) = user_name {
656 if Self::owner_from_metadata(&edge.metadata) != un {
657 continue;
658 }
659 }
660 if let Some(rel) = relation {
661 if edge.relation != rel {
662 continue;
663 }
664 }
665
666 let next = match direction {
667 GraphDirection::Outbound => {
668 if edge.from == state.current {
669 Some(edge.to.clone())
670 } else {
671 None
672 }
673 }
674 GraphDirection::Inbound => {
675 if edge.to == state.current {
676 Some(edge.from.clone())
677 } else {
678 None
679 }
680 }
681 GraphDirection::Both => {
682 if edge.from == state.current {
683 Some(edge.to.clone())
684 } else if edge.to == state.current {
685 Some(edge.from.clone())
686 } else {
687 None
688 }
689 }
690 };
691
692 let Some(next_node_id) = next else { continue };
693 if state.visited.contains(&next_node_id) {
694 continue;
695 }
696 let Some(next_node) = node_guard.get(&next_node_id) else {
697 continue;
698 };
699 if let Some(un) = user_name {
700 if Self::owner_from_metadata(&next_node.metadata) != un {
701 continue;
702 }
703 }
704 if !include_deleted && Self::is_tombstone(&next_node.metadata) {
705 continue;
706 }
707 transitions.push((next_node_id, edge));
708 }
709 transitions.sort_by(|a, b| a.1.id.cmp(&b.1.id).then_with(|| a.0.cmp(&b.0)));
710
711 for (next_node_id, edge) in transitions {
712 let mut next_state = state.clone();
713 next_state.current = next_node_id.clone();
714 next_state.node_ids.push(next_node_id.clone());
715 next_state.edges.push(edge);
716 next_state.visited.insert(next_node_id);
717 queue.push_back(next_state);
718 }
719 }
720
721 Ok(results)
722 }
723
724 async fn search_by_embedding(
725 &self,
726 vector: &[f32],
727 top_k: usize,
728 user_name: Option<&str>,
729 ) -> Result<Vec<VecSearchHit>, GraphStoreError> {
730 let guard = self.nodes.read().await;
731 let un = user_name.unwrap_or("");
732 let mut candidates: Vec<(String, f64)> = Vec::new();
733 for node in guard.values() {
734 if !un.is_empty() {
735 let node_user = Self::owner_from_metadata(&node.metadata);
736 if node_user != un {
737 continue;
738 }
739 }
740 let emb = match &node.embedding {
741 Some(e) => e,
742 None => continue,
743 };
744 if emb.len() != vector.len() {
745 continue;
746 }
747 let score = cosine_similarity(vector, emb);
748 candidates.push((node.id.clone(), score));
749 }
750 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
751 let hits = candidates
752 .into_iter()
753 .take(top_k)
754 .map(|(id, score)| VecSearchHit { id, score })
755 .collect();
756 Ok(hits)
757 }
758
759 async fn get_all_memory_items(
760 &self,
761 scope: &str,
762 user_name: &str,
763 include_embedding: bool,
764 ) -> Result<Vec<MemoryNode>, GraphStoreError> {
765 let ids = {
766 let idx = self.scope_index.read().await;
767 idx.get(user_name)
768 .and_then(|m| m.get(scope))
769 .cloned()
770 .unwrap_or_default()
771 };
772 let mut nodes = self.get_nodes(&ids, include_embedding).await?;
773 nodes.sort_by(|a, b| a.id.cmp(&b.id));
774 Ok(nodes)
775 }
776
777 async fn update_node(
778 &self,
779 id: &str,
780 fields: &HashMap<String, serde_json::Value>,
781 user_name: Option<&str>,
782 ) -> Result<(), GraphStoreError> {
783 let mut guard = self.nodes.write().await;
784 let node = guard
785 .get_mut(id)
786 .ok_or_else(|| GraphStoreError::Other(format!("node not found: {}", id)))?;
787 if let Some(un) = user_name {
788 let node_owner = Self::owner_from_metadata(&node.metadata);
789 if node_owner != un {
790 return Err(GraphStoreError::Other(format!(
791 "node not found or access denied: {}",
792 id
793 )));
794 }
795 }
796 for (k, v) in fields {
797 if k == "memory" {
798 node.memory = v.as_str().unwrap_or("").to_string();
799 } else {
800 node.metadata.insert(k.clone(), v.clone());
801 }
802 }
803 Ok(())
804 }
805
806 async fn delete_node(&self, id: &str, user_name: Option<&str>) -> Result<(), GraphStoreError> {
807 {
808 let nodes = self.nodes.read().await;
809 let node = nodes
810 .get(id)
811 .ok_or_else(|| GraphStoreError::Other(format!("node not found: {}", id)))?;
812 if let Some(un) = user_name {
813 let node_owner = Self::owner_from_metadata(&node.metadata);
814 if node_owner != un {
815 return Err(GraphStoreError::Other(format!(
816 "node not found or access denied: {}",
817 id
818 )));
819 }
820 }
821 }
822 {
823 let mut nodes = self.nodes.write().await;
824 nodes
825 .remove(id)
826 .ok_or_else(|| GraphStoreError::Other(format!("node not found: {}", id)))?;
827 }
828 {
829 let mut idx = self.scope_index.write().await;
830 for scope_map in idx.values_mut() {
831 for list in scope_map.values_mut() {
832 list.retain(|x| x != id);
833 }
834 }
835 }
836 self.delete_edges_by_node(id, user_name).await?;
837 Ok(())
838 }
839
840 async fn delete_edges_by_node(
841 &self,
842 id: &str,
843 user_name: Option<&str>,
844 ) -> Result<usize, GraphStoreError> {
845 let mut edge_ids: HashSet<String> = HashSet::new();
846 {
847 let out_guard = self.out_index.read().await;
848 edge_ids.extend(out_guard.get(id).cloned().unwrap_or_default());
849 }
850 {
851 let in_guard = self.in_index.read().await;
852 edge_ids.extend(in_guard.get(id).cloned().unwrap_or_default());
853 }
854
855 if edge_ids.is_empty() {
856 return Ok(0);
857 }
858
859 let mut edge_guard = self.edges.write().await;
860 let mut out_guard = self.out_index.write().await;
861 let mut in_guard = self.in_index.write().await;
862
863 if let Some(un) = user_name {
864 for edge_id in &edge_ids {
865 if let Some(edge) = edge_guard.get(edge_id) {
866 if Self::owner_from_metadata(&edge.metadata) != un {
867 return Err(GraphStoreError::Other(format!(
868 "edge not found or access denied: {}",
869 edge_id
870 )));
871 }
872 }
873 }
874 }
875
876 let mut deleted = 0usize;
877 for edge_id in edge_ids {
878 if let Some(edge) = edge_guard.remove(&edge_id) {
879 Self::remove_edge_indexes(&edge, &mut out_guard, &mut in_guard);
880 deleted += 1;
881 }
882 }
883 Ok(deleted)
884 }
885}
886
/// Cosine similarity of two vectors, computed in `f64`.
///
/// Returns `0.0` when the lengths differ, the vectors are empty, or either vector has zero
/// magnitude (cases where the similarity is undefined).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    let magnitude = |v: &[f32]| v.iter().map(|&x| f64::from(x).powi(2)).sum::<f64>().sqrt();

    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    let dot = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum::<f64>();
    let (norm_a, norm_b) = (magnitude(a), magnitude(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
903
#[cfg(test)]
mod tests {
    use super::*;

    /// Owner used for every node and edge in these tests.
    const USER: &str = "u1";

    /// Node metadata that files the node under the `LongTermMemory` scope.
    fn long_term_meta() -> HashMap<String, serde_json::Value> {
        let mut meta = HashMap::new();
        meta.insert(
            "scope".to_string(),
            serde_json::Value::String("LongTermMemory".to_string()),
        );
        meta
    }

    /// Edge with the given endpoints and relation, and no extra metadata.
    fn edge(id: &str, from: &str, to: &str, relation: &str) -> MemoryEdge {
        MemoryEdge {
            id: id.to_string(),
            from: from.to_string(),
            to: to.to_string(),
            relation: relation.to_string(),
            metadata: HashMap::new(),
        }
    }

    /// Store owned by `USER`, holding `nodes` as `(id, memory)` pairs linked by `edges`.
    async fn store_with(nodes: &[(&str, &str)], edges: &[MemoryEdge]) -> InMemoryGraphStore {
        let store = InMemoryGraphStore::new();
        let meta = long_term_meta();
        for &(id, memory) in nodes {
            store.add_node(id, memory, &meta, Some(USER)).await.unwrap();
        }
        store.add_edges_batch(edges, Some(USER)).await.unwrap();
        store
    }

    #[tokio::test]
    async fn neighbors_are_deterministic_and_limited() {
        let store = store_with(
            &[("n0", "root"), ("n1", "node1"), ("n2", "node2")],
            &[
                edge("e2", "n0", "n2", "related_to"),
                edge("e1", "n0", "n1", "related_to"),
            ],
        )
        .await;

        let all = store
            .get_neighbors("n0", Some("related_to"), GraphDirection::Outbound, 10, false, Some(USER))
            .await
            .unwrap();
        let all_ids: Vec<&str> = all.iter().map(|n| n.edge.id.as_str()).collect();
        assert_eq!(all_ids, ["e1", "e2"]);

        let limited = store
            .get_neighbors("n0", Some("related_to"), GraphDirection::Outbound, 1, false, Some(USER))
            .await
            .unwrap();
        let limited_ids: Vec<&str> = limited.iter().map(|n| n.edge.id.as_str()).collect();
        assert_eq!(limited_ids, ["e1"]);
    }

    #[tokio::test]
    async fn shortest_path_finds_min_hops() {
        let store = store_with(
            &[("a", "A"), ("b", "B"), ("c", "C"), ("d", "D")],
            &[
                edge("e_ab", "a", "b", "related_to"),
                edge("e_bc", "b", "c", "related_to"),
                edge("e_ad", "a", "d", "related_to"),
                edge("e_dc", "d", "c", "related_to"),
            ],
        )
        .await;

        let path = store
            .shortest_path("a", "c", Some("related_to"), GraphDirection::Outbound, 3, false, Some(USER))
            .await
            .unwrap()
            .expect("a path from a to c exists");
        assert_eq!(path.node_ids.first().map(String::as_str), Some("a"));
        assert_eq!(path.node_ids.last().map(String::as_str), Some("c"));
        assert_eq!(path.edges.len(), 2);
    }

    #[tokio::test]
    async fn find_paths_returns_top_k_shortest() {
        let store = store_with(
            &[("s", "s"), ("a", "a"), ("b", "b"), ("t", "t")],
            &[
                edge("e_sa", "s", "a", "r"),
                edge("e_at", "a", "t", "r"),
                edge("e_sb", "s", "b", "r"),
                edge("e_bt", "b", "t", "r"),
            ],
        )
        .await;

        let paths = store
            .find_paths("s", "t", Some("r"), GraphDirection::Outbound, 3, 2, false, Some(USER))
            .await
            .unwrap();
        assert_eq!(paths.len(), 2);
        assert!(paths.iter().all(|p| p.edges.len() == 2));
    }
}