1use crate::error::AgentRuntimeError;
17use crate::util::recover_lock;
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
21use std::sync::{Arc, Mutex};
22
/// Totally ordered `f32` wrapper so floating-point costs can be stored in a
/// `BinaryHeap`.
///
/// Ordering and equality follow IEEE 754 `totalOrder` via [`f32::total_cmp`]:
/// `-0.0 < +0.0`, a NaN compares equal to an identical NaN, and positive NaNs
/// sort above `+inf` (negative NaNs below `-inf`). Unlike a `partial_cmp`
/// fallback, this keeps `Eq`/`Ord` lawful (reflexive, antisymmetric,
/// transitive), which `BinaryHeap` relies on.
#[derive(Debug, Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    // Defined through `cmp` so equality can never disagree with the ordering.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
45
/// Unique, string-backed identifier of an [`Entity`].
///
/// `Eq` + `Hash` let it key the graph's maps and sets; `Ord` lets id
/// collections be sorted and lets `(cost, EntityId)` heap entries be compared.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct EntityId(pub String);
51
52impl EntityId {
53 pub fn new(id: impl Into<String>) -> Self {
55 Self(id.into())
56 }
57
58 pub fn as_str(&self) -> &str {
60 &self.0
61 }
62}
63
64impl AsRef<str> for EntityId {
65 fn as_ref(&self) -> &str {
66 &self.0
67 }
68}
69
70impl std::fmt::Display for EntityId {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 write!(f, "{}", self.0)
73 }
74}
75
/// A graph node: a unique id, a free-form label, and arbitrary JSON properties.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Entity {
    /// Unique identifier; `GraphStore` also uses it as the storage key.
    pub id: EntityId,
    /// Free-form label / type name (e.g. "Person").
    pub label: String,
    /// Arbitrary JSON-valued properties keyed by name.
    pub properties: HashMap<String, Value>,
}
88
89impl Entity {
90 pub fn new(id: impl Into<String>, label: impl Into<String>) -> Self {
92 Self {
93 id: EntityId::new(id),
94 label: label.into(),
95 properties: HashMap::new(),
96 }
97 }
98
99 pub fn with_properties(
101 id: impl Into<String>,
102 label: impl Into<String>,
103 properties: HashMap<String, Value>,
104 ) -> Self {
105 Self {
106 id: EntityId::new(id),
107 label: label.into(),
108 properties,
109 }
110 }
111}
112
/// A directed, typed, weighted edge from one entity to another.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Relationship {
    /// Source entity id.
    pub from: EntityId,
    /// Target entity id.
    pub to: EntityId,
    /// Relationship type; `GraphStore::add_relationship` rejects a second edge
    /// with the same `(from, to, kind)`.
    pub kind: String,
    /// Edge weight. Only `shortest_path_weighted` reads it (and rejects negative
    /// values); the other algorithms treat every edge as one hop.
    pub weight: f32,
}
127
128impl Relationship {
129 pub fn new(
131 from: impl Into<String>,
132 to: impl Into<String>,
133 kind: impl Into<String>,
134 weight: f32,
135 ) -> Self {
136 Self {
137 from: EntityId::new(from),
138 to: EntityId::new(to),
139 kind: kind.into(),
140 weight,
141 }
142 }
143}
144
/// Errors specific to graph operations.
///
/// Converted into `AgentRuntimeError::Graph` (carrying the rendered message) by
/// the `From` impl below. In this file only `DuplicateRelationship` is actually
/// constructed; other failures are reported as ad-hoc `Graph` strings.
#[derive(Debug, thiserror::Error)]
pub enum MemGraphError {
    /// No entity with the given id exists.
    #[error("Entity '{0}' not found")]
    EntityNotFound(String),

    /// An edge with the same source, target and kind is already present.
    #[error("Relationship '{kind}' from '{from}' to '{to}' already exists")]
    DuplicateRelationship {
        /// Source entity id.
        from: String,
        /// Target entity id.
        to: String,
        /// Relationship kind.
        kind: String,
    },

    /// Unexpected internal failure.
    #[error("Graph internal error: {0}")]
    Internal(String),
}
169
170impl From<MemGraphError> for AgentRuntimeError {
171 fn from(e: MemGraphError) -> Self {
172 AgentRuntimeError::Graph(e.to_string())
173 }
174}
175
/// In-memory directed multigraph of [`Entity`] nodes and [`Relationship`] edges.
///
/// All state lives behind a single `Arc<Mutex<GraphInner>>`. Cloning a
/// `GraphStore` therefore yields another handle to the *same* graph, not a
/// copy. The methods in this file take the lock through `recover_lock`
/// (presumably tolerating a poisoned mutex — see `crate::util`) and hold it
/// for the duration of the call.
#[derive(Debug, Clone)]
pub struct GraphStore {
    inner: Arc<Mutex<GraphInner>>,
}
190
/// Graph state guarded by `GraphStore`'s mutex.
#[derive(Debug)]
struct GraphInner {
    /// Entities keyed by their own `id`.
    entities: HashMap<EntityId, Entity>,
    /// Every edge in insertion order; neighbour lookups scan this list linearly.
    relationships: Vec<Relationship>,
    /// Memoised result of `detect_cycles`; reset to `None` by each mutating method.
    cycle_cache: Option<bool>,
}
198
199impl GraphStore {
200 pub fn new() -> Self {
202 Self {
203 inner: Arc::new(Mutex::new(GraphInner {
204 entities: HashMap::new(),
205 relationships: Vec::new(),
206 cycle_cache: None,
207 })),
208 }
209 }
210
211 pub fn add_entity(&self, entity: Entity) -> Result<(), AgentRuntimeError> {
215 let mut inner = recover_lock(self.inner.lock(), "add_entity");
216 inner.cycle_cache = None;
217 inner.entities.insert(entity.id.clone(), entity);
218 Ok(())
219 }
220
221 pub fn get_entity(&self, id: &EntityId) -> Result<Entity, AgentRuntimeError> {
223 let inner = recover_lock(self.inner.lock(), "get_entity");
224 inner
225 .entities
226 .get(id)
227 .cloned()
228 .ok_or_else(|| AgentRuntimeError::Graph(format!("entity '{}' not found", id.0)))
229 }
230
231 pub fn add_relationship(&self, rel: Relationship) -> Result<(), AgentRuntimeError> {
235 let mut inner = recover_lock(self.inner.lock(), "add_relationship");
236
237 if !inner.entities.contains_key(&rel.from) {
238 return Err(AgentRuntimeError::Graph(format!(
239 "source entity '{}' not found",
240 rel.from.0
241 )));
242 }
243 if !inner.entities.contains_key(&rel.to) {
244 return Err(AgentRuntimeError::Graph(format!(
245 "target entity '{}' not found",
246 rel.to.0
247 )));
248 }
249
250 let duplicate = inner
254 .relationships
255 .iter()
256 .any(|r| r.from == rel.from && r.to == rel.to && r.kind == rel.kind);
257 if duplicate {
258 return Err(AgentRuntimeError::Graph(
259 MemGraphError::DuplicateRelationship {
260 from: rel.from.0.clone(),
261 to: rel.to.0.clone(),
262 kind: rel.kind.clone(),
263 }
264 .to_string(),
265 ));
266 }
267
268 inner.cycle_cache = None;
269 inner.relationships.push(rel);
270 Ok(())
271 }
272
273 pub fn remove_entity(&self, id: &EntityId) -> Result<(), AgentRuntimeError> {
275 let mut inner = recover_lock(self.inner.lock(), "remove_entity");
276
277 if inner.entities.remove(id).is_none() {
278 return Err(AgentRuntimeError::Graph(format!(
279 "entity '{}' not found",
280 id.0
281 )));
282 }
283 inner.cycle_cache = None;
284 inner.relationships.retain(|r| &r.from != id && &r.to != id);
285 Ok(())
286 }
287
288 fn neighbours(relationships: &[Relationship], id: &EntityId) -> Vec<EntityId> {
290 relationships
291 .iter()
292 .filter(|r| &r.from == id)
293 .map(|r| r.to.clone())
294 .collect()
295 }
296
297 #[tracing::instrument(skip(self))]
301 pub fn bfs(&self, start: &EntityId) -> Result<Vec<EntityId>, AgentRuntimeError> {
302 let inner = recover_lock(self.inner.lock(), "bfs");
303
304 if !inner.entities.contains_key(start) {
305 return Err(AgentRuntimeError::Graph(format!(
306 "start entity '{}' not found",
307 start.0
308 )));
309 }
310
311 let mut visited: HashSet<EntityId> = HashSet::new();
312 let mut queue: VecDeque<EntityId> = VecDeque::new();
313 let mut result: Vec<EntityId> = Vec::new();
314
315 visited.insert(start.clone());
316 queue.push_back(start.clone());
317
318 while let Some(current) = queue.pop_front() {
319 let neighbours: Vec<EntityId> = Self::neighbours(&inner.relationships, ¤t);
320 for neighbour in neighbours {
321 if visited.insert(neighbour.clone()) {
322 result.push(neighbour.clone());
323 queue.push_back(neighbour);
324 }
325 }
326 }
327
328 tracing::debug!("BFS visited {} nodes", result.len());
329 Ok(result)
330 }
331
332 #[tracing::instrument(skip(self))]
336 pub fn dfs(&self, start: &EntityId) -> Result<Vec<EntityId>, AgentRuntimeError> {
337 let inner = recover_lock(self.inner.lock(), "dfs");
338
339 if !inner.entities.contains_key(start) {
340 return Err(AgentRuntimeError::Graph(format!(
341 "start entity '{}' not found",
342 start.0
343 )));
344 }
345
346 let mut visited: HashSet<EntityId> = HashSet::new();
347 let mut stack: Vec<EntityId> = Vec::new();
348 let mut result: Vec<EntityId> = Vec::new();
349
350 visited.insert(start.clone());
351 stack.push(start.clone());
352
353 while let Some(current) = stack.pop() {
354 let neighbours: Vec<EntityId> = Self::neighbours(&inner.relationships, ¤t);
355 for neighbour in neighbours {
356 if visited.insert(neighbour.clone()) {
357 result.push(neighbour.clone());
358 stack.push(neighbour);
359 }
360 }
361 }
362
363 tracing::debug!("DFS visited {} nodes", result.len());
364 Ok(result)
365 }
366
367 #[tracing::instrument(skip(self))]
373 pub fn shortest_path(
374 &self,
375 from: &EntityId,
376 to: &EntityId,
377 ) -> Result<Option<Vec<EntityId>>, AgentRuntimeError> {
378 let inner = recover_lock(self.inner.lock(), "shortest_path");
379
380 if !inner.entities.contains_key(from) {
381 return Err(AgentRuntimeError::Graph(format!(
382 "source entity '{}' not found",
383 from.0
384 )));
385 }
386 if !inner.entities.contains_key(to) {
387 return Err(AgentRuntimeError::Graph(format!(
388 "target entity '{}' not found",
389 to.0
390 )));
391 }
392
393 if from == to {
394 return Ok(Some(vec![from.clone()]));
395 }
396
397 let mut visited: HashSet<EntityId> = HashSet::new();
399 let mut prev: HashMap<EntityId, EntityId> = HashMap::new();
400 let mut queue: VecDeque<EntityId> = VecDeque::new();
401
402 visited.insert(from.clone());
403 queue.push_back(from.clone());
404
405 while let Some(current) = queue.pop_front() {
406 for neighbour in Self::neighbours(&inner.relationships, ¤t) {
407 if &neighbour == to {
408 let mut path = vec![neighbour, current.clone()];
410 let mut node = current;
411 while let Some(p) = prev.get(&node) {
412 path.push(p.clone());
413 node = p.clone();
414 }
415 path.reverse();
416 return Ok(Some(path));
417 }
418 if visited.insert(neighbour.clone()) {
419 prev.insert(neighbour.clone(), current.clone());
420 queue.push_back(neighbour);
421 }
422 }
423 }
424
425 Ok(None)
426 }
427
428 pub fn shortest_path_weighted(
438 &self,
439 from: &EntityId,
440 to: &EntityId,
441 ) -> Result<Option<(Vec<EntityId>, f32)>, AgentRuntimeError> {
442 let inner = recover_lock(self.inner.lock(), "shortest_path_weighted");
443
444 if !inner.entities.contains_key(from) {
445 return Err(AgentRuntimeError::Graph(format!(
446 "source entity '{}' not found",
447 from.0
448 )));
449 }
450 if !inner.entities.contains_key(to) {
451 return Err(AgentRuntimeError::Graph(format!(
452 "target entity '{}' not found",
453 to.0
454 )));
455 }
456
457 for rel in &inner.relationships {
459 if rel.weight < 0.0 {
460 return Err(AgentRuntimeError::Graph(format!(
461 "negative weight {:.4} on edge '{}' -> '{}'",
462 rel.weight, rel.from.0, rel.to.0
463 )));
464 }
465 }
466
467 if from == to {
468 return Ok(Some((vec![from.clone()], 0.0)));
469 }
470
471 let mut dist: HashMap<EntityId, f32> = HashMap::new();
474 let mut prev: HashMap<EntityId, EntityId> = HashMap::new();
475 let mut heap: BinaryHeap<(OrdF32, EntityId)> = BinaryHeap::new();
477
478 dist.insert(from.clone(), 0.0);
479 heap.push((OrdF32(-0.0), from.clone()));
480
481 while let Some((OrdF32(neg_cost), current)) = heap.pop() {
482 let cost = -neg_cost;
483
484 if let Some(&best) = dist.get(¤t) {
486 if cost > best {
487 continue;
488 }
489 }
490
491 if ¤t == to {
492 let mut path = vec![to.clone()];
494 let mut node = to.clone();
495 while let Some(p) = prev.get(&node) {
496 path.push(p.clone());
497 node = p.clone();
498 }
499 path.reverse();
500 return Ok(Some((path, cost)));
501 }
502
503 for rel in inner.relationships.iter().filter(|r| &r.from == ¤t) {
504 let next_cost = cost + rel.weight;
505 let entry = dist.entry(rel.to.clone()).or_insert(f32::INFINITY);
506 if next_cost < *entry {
507 *entry = next_cost;
508 prev.insert(rel.to.clone(), current.clone());
509 heap.push((OrdF32(-next_cost), rel.to.clone()));
510 }
511 }
512 }
513
514 Ok(None)
515 }
516
517 fn bfs_into_set(inner: &GraphInner, start: &EntityId) -> HashSet<EntityId> {
522 let mut visited: HashSet<EntityId> = HashSet::new();
523 let mut queue: VecDeque<EntityId> = VecDeque::new();
524 visited.insert(start.clone());
525 queue.push_back(start.clone());
526 while let Some(current) = queue.pop_front() {
527 for neighbour in Self::neighbours(&inner.relationships, ¤t) {
528 if visited.insert(neighbour.clone()) {
529 queue.push_back(neighbour);
530 }
531 }
532 }
533 visited
534 }
535
536 pub fn transitive_closure(
545 &self,
546 start: &EntityId,
547 ) -> Result<HashSet<EntityId>, AgentRuntimeError> {
548 let inner = recover_lock(self.inner.lock(), "transitive_closure");
549 if !inner.entities.contains_key(start) {
550 return Err(AgentRuntimeError::Graph(format!(
551 "start entity '{}' not found",
552 start.0
553 )));
554 }
555 Ok(Self::bfs_into_set(&inner, start))
556 }
557
558 pub fn entity_count(&self) -> Result<usize, AgentRuntimeError> {
560 let inner = recover_lock(self.inner.lock(), "entity_count");
561 Ok(inner.entities.len())
562 }
563
564 pub fn relationship_count(&self) -> Result<usize, AgentRuntimeError> {
566 let inner = recover_lock(self.inner.lock(), "relationship_count");
567 Ok(inner.relationships.len())
568 }
569
570 pub fn degree_centrality(&self) -> Result<HashMap<EntityId, f32>, AgentRuntimeError> {
573 let inner = recover_lock(self.inner.lock(), "degree_centrality");
574 let n = inner.entities.len();
575
576 let mut out_degree: HashMap<EntityId, usize> = HashMap::new();
577 let mut in_degree: HashMap<EntityId, usize> = HashMap::new();
578
579 for id in inner.entities.keys() {
580 out_degree.insert(id.clone(), 0);
581 in_degree.insert(id.clone(), 0);
582 }
583
584 for rel in &inner.relationships {
585 *out_degree.entry(rel.from.clone()).or_insert(0) += 1;
586 *in_degree.entry(rel.to.clone()).or_insert(0) += 1;
587 }
588
589 let denom = if n <= 1 { 1.0 } else { (n - 1) as f32 };
590 let mut result = HashMap::new();
591 for id in inner.entities.keys() {
592 let od = *out_degree.get(id).unwrap_or(&0);
593 let id_ = *in_degree.get(id).unwrap_or(&0);
594 let centrality = if n <= 1 {
595 0.0
596 } else {
597 (od + id_) as f32 / denom
598 };
599 result.insert(id.clone(), centrality);
600 }
601
602 Ok(result)
603 }
604
605 pub fn betweenness_centrality(&self) -> Result<HashMap<EntityId, f32>, AgentRuntimeError> {
611 let inner = recover_lock(self.inner.lock(), "betweenness_centrality");
612 let n = inner.entities.len();
613 let nodes: Vec<EntityId> = inner.entities.keys().cloned().collect();
614
615 let mut centrality: HashMap<EntityId, f32> =
616 nodes.iter().map(|id| (id.clone(), 0.0f32)).collect();
617
618 for source in &nodes {
619 let mut stack: Vec<EntityId> = Vec::new();
621 let mut predecessors: HashMap<EntityId, Vec<EntityId>> =
622 nodes.iter().map(|id| (id.clone(), vec![])).collect();
623 let mut sigma: HashMap<EntityId, f32> =
624 nodes.iter().map(|id| (id.clone(), 0.0f32)).collect();
625 let mut dist: HashMap<EntityId, i64> =
626 nodes.iter().map(|id| (id.clone(), -1i64)).collect();
627
628 *sigma.entry(source.clone()).or_insert(0.0) = 1.0;
629 *dist.entry(source.clone()).or_insert(-1) = 0;
630
631 let mut queue: VecDeque<EntityId> = VecDeque::new();
632 queue.push_back(source.clone());
633
634 while let Some(v) = queue.pop_front() {
635 stack.push(v.clone());
636 let d_v = *dist.get(&v).unwrap_or(&0);
637 let sigma_v = *sigma.get(&v).unwrap_or(&0.0);
638 for rel in inner.relationships.iter().filter(|r| &r.from == &v) {
639 let w = &rel.to;
640 let d_w = dist.get(w).copied().unwrap_or(-1);
641 if d_w < 0 {
642 queue.push_back(w.clone());
643 *dist.entry(w.clone()).or_insert(-1) = d_v + 1;
644 }
645 if dist.get(w).copied().unwrap_or(-1) == d_v + 1 {
646 *sigma.entry(w.clone()).or_insert(0.0) += sigma_v;
647 predecessors.entry(w.clone()).or_default().push(v.clone());
648 }
649 }
650 }
651
652 let mut delta: HashMap<EntityId, f32> =
653 nodes.iter().map(|id| (id.clone(), 0.0f32)).collect();
654
655 while let Some(w) = stack.pop() {
656 let delta_w = *delta.get(&w).unwrap_or(&0.0);
657 let sigma_w = *sigma.get(&w).unwrap_or(&1.0);
658 for v in predecessors
660 .get(&w)
661 .map(|ps| ps.as_slice())
662 .unwrap_or_default()
663 {
664 let sigma_v = *sigma.get(v).unwrap_or(&1.0);
665 let contribution = (sigma_v / sigma_w) * (1.0 + delta_w);
666 *delta.entry(v.clone()).or_insert(0.0) += contribution;
667 }
668 if &w != source {
669 *centrality.entry(w.clone()).or_insert(0.0) += delta_w;
670 }
671 }
672 }
673
674 if n > 2 {
676 let norm = 2.0 / (((n - 1) * (n - 2)) as f32);
677 for v in centrality.values_mut() {
678 *v *= norm;
679 }
680 } else {
681 for v in centrality.values_mut() {
682 *v = 0.0;
683 }
684 }
685
686 Ok(centrality)
687 }
688
689 pub fn label_propagation_communities(
694 &self,
695 max_iterations: usize,
696 ) -> Result<HashMap<EntityId, usize>, AgentRuntimeError> {
697 let inner = recover_lock(self.inner.lock(), "label_propagation_communities");
698 let nodes: Vec<EntityId> = inner.entities.keys().cloned().collect();
699
700 let mut labels: HashMap<EntityId, usize> = nodes
702 .iter()
703 .enumerate()
704 .map(|(i, id)| (id.clone(), i))
705 .collect();
706
707 for _ in 0..max_iterations {
708 let mut changed = false;
709 for node in &nodes {
711 let neighbour_labels: Vec<usize> = inner
713 .relationships
714 .iter()
715 .filter(|r| &r.from == node || &r.to == node)
716 .map(|r| {
717 if &r.from == node {
718 labels.get(&r.to).copied().unwrap_or(0)
719 } else {
720 labels.get(&r.from).copied().unwrap_or(0)
721 }
722 })
723 .collect();
724
725 if neighbour_labels.is_empty() {
726 continue;
727 }
728
729 let mut freq: HashMap<usize, usize> = HashMap::new();
731 for &lbl in &neighbour_labels {
732 *freq.entry(lbl).or_insert(0) += 1;
733 }
734 let best = freq
735 .into_iter()
736 .max_by_key(|&(_, count)| count)
737 .map(|(lbl, _)| lbl);
738
739 if let Some(new_label) = best {
740 let current = labels.entry(node.clone()).or_insert(0);
741 if *current != new_label {
742 *current = new_label;
743 changed = true;
744 }
745 }
746 }
747
748 if !changed {
749 break;
750 }
751 }
752
753 Ok(labels)
754 }
755
756 pub fn detect_cycles(&self) -> Result<bool, AgentRuntimeError> {
766 let mut inner = recover_lock(self.inner.lock(), "detect_cycles");
767
768 if let Some(cached) = inner.cycle_cache {
769 return Ok(cached);
770 }
771
772 let mut color: HashMap<&EntityId, u8> =
774 inner.entities.keys().map(|id| (id, 0u8)).collect();
775
776 let has_cycle = 'outer: {
777 for start in inner.entities.keys() {
778 if *color.get(start).unwrap_or(&0) != 0 {
779 continue;
780 }
781
782 let mut stack: Vec<(&EntityId, usize)> = vec![(start, 0)];
784 *color.entry(start).or_insert(0) = 1;
785
786 while let Some((node, idx)) = stack.last_mut() {
787 let neighbors: Vec<&EntityId> = inner
789 .relationships
790 .iter()
791 .filter(|r| &r.from == *node)
792 .map(|r| &r.to)
793 .collect();
794
795 if *idx < neighbors.len() {
796 let next = neighbors[*idx];
797 *idx += 1;
798 match color.get(next).copied().unwrap_or(0) {
799 1 => break 'outer true, 0 => {
801 *color.entry(next).or_insert(0) = 1;
802 stack.push((next, 0));
803 }
804 _ => {} }
806 } else {
807 *color.entry(*node).or_insert(0) = 2;
809 stack.pop();
810 }
811 }
812 }
813 false
814 };
815
816 inner.cycle_cache = Some(has_cycle);
817 Ok(has_cycle)
818 }
819
820 pub fn path_exists(&self, from: &str, to: &str) -> Result<bool, AgentRuntimeError> {
825 let from_id = EntityId::new(from);
826 let to_id = EntityId::new(to);
827 match self.shortest_path(&from_id, &to_id) {
828 Ok(Some(_)) => Ok(true),
829 Ok(None) => Ok(false),
830 Err(e) => Err(e),
831 }
832 }
833
834 pub fn bfs_bounded(
838 &self,
839 start: &str,
840 max_depth: usize,
841 max_nodes: usize,
842 ) -> Result<Vec<EntityId>, AgentRuntimeError> {
843 let inner = recover_lock(self.inner.lock(), "bfs_bounded");
844 let start_id = EntityId::new(start);
845 if !inner.entities.contains_key(&start_id) {
846 return Err(AgentRuntimeError::Graph(format!(
847 "start entity '{start}' not found"
848 )));
849 }
850
851 let mut visited: std::collections::HashMap<EntityId, usize> = std::collections::HashMap::new();
852 let mut queue: VecDeque<(EntityId, usize)> = VecDeque::new();
853 let mut result: Vec<EntityId> = Vec::new();
854
855 visited.insert(start_id.clone(), 0);
856 queue.push_back((start_id.clone(), 0));
857 result.push(start_id);
858
859 while let Some((current, depth)) = queue.pop_front() {
860 if result.len() >= max_nodes {
861 break;
862 }
863 if depth >= max_depth {
864 continue;
865 }
866 for neighbour in Self::neighbours(&inner.relationships, ¤t) {
867 if !visited.contains_key(&neighbour) {
868 let new_depth = depth + 1;
869 visited.insert(neighbour.clone(), new_depth);
870 result.push(neighbour.clone());
871 if result.len() >= max_nodes {
872 break;
873 }
874 queue.push_back((neighbour, new_depth));
875 }
876 }
877 }
878
879 Ok(result)
880 }
881
882 pub fn dfs_bounded(
886 &self,
887 start: &str,
888 max_depth: usize,
889 max_nodes: usize,
890 ) -> Result<Vec<EntityId>, AgentRuntimeError> {
891 let inner = recover_lock(self.inner.lock(), "dfs_bounded");
892 let start_id = EntityId::new(start);
893 if !inner.entities.contains_key(&start_id) {
894 return Err(AgentRuntimeError::Graph(format!(
895 "start entity '{start}' not found"
896 )));
897 }
898
899 let mut visited: HashSet<EntityId> = HashSet::new();
900 let mut stack: Vec<(EntityId, usize)> = Vec::new();
901 let mut result: Vec<EntityId> = Vec::new();
902
903 visited.insert(start_id.clone());
904 stack.push((start_id.clone(), 0));
905 result.push(start_id);
906
907 while let Some((current, depth)) = stack.pop() {
908 if result.len() >= max_nodes {
909 break;
910 }
911 if depth >= max_depth {
912 continue;
913 }
914 for neighbour in Self::neighbours(&inner.relationships, ¤t) {
915 if visited.insert(neighbour.clone()) {
916 result.push(neighbour.clone());
917 if result.len() >= max_nodes {
918 break;
919 }
920 stack.push((neighbour, depth + 1));
921 }
922 }
923 }
924
925 Ok(result)
926 }
927
928 pub fn subgraph(&self, node_ids: &[EntityId]) -> Result<GraphStore, AgentRuntimeError> {
931 let inner = recover_lock(self.inner.lock(), "subgraph");
932 let id_set: HashSet<&EntityId> = node_ids.iter().collect();
933
934 let new_store = GraphStore::new();
935
936 for id in node_ids {
937 let entity = inner
938 .entities
939 .get(id)
940 .ok_or_else(|| AgentRuntimeError::Graph(format!("entity '{}' not found", id.0)))?
941 .clone();
942 let mut new_inner = recover_lock(new_store.inner.lock(), "subgraph:add_entity");
944 new_inner.entities.insert(entity.id.clone(), entity);
945 }
946
947 for rel in inner.relationships.iter() {
948 if id_set.contains(&rel.from) && id_set.contains(&rel.to) {
949 let mut new_inner =
950 recover_lock(new_store.inner.lock(), "subgraph:add_relationship");
951 new_inner.relationships.push(rel.clone());
952 }
953 }
954
955 Ok(new_store)
956 }
957}
958
959impl Default for GraphStore {
960 fn default() -> Self {
961 Self::new()
962 }
963}
964
965#[cfg(test)]
968mod tests {
969 use super::*;
970
971 fn make_graph() -> GraphStore {
972 GraphStore::new()
973 }
974
975 fn add(g: &GraphStore, id: &str) {
976 g.add_entity(Entity::new(id, "Node")).unwrap();
977 }
978
979 fn link(g: &GraphStore, from: &str, to: &str) {
980 g.add_relationship(Relationship::new(from, to, "CONNECTS", 1.0))
981 .unwrap();
982 }
983
984 fn link_w(g: &GraphStore, from: &str, to: &str, weight: f32) {
985 g.add_relationship(Relationship::new(from, to, "CONNECTS", weight))
986 .unwrap();
987 }
988
989 #[test]
992 fn test_entity_id_equality() {
993 assert_eq!(EntityId::new("a"), EntityId::new("a"));
994 assert_ne!(EntityId::new("a"), EntityId::new("b"));
995 }
996
997 #[test]
998 fn test_entity_id_display() {
999 let id = EntityId::new("hello");
1000 assert_eq!(id.to_string(), "hello");
1001 }
1002
1003 #[test]
1006 fn test_entity_new_has_empty_properties() {
1007 let e = Entity::new("e1", "Person");
1008 assert!(e.properties.is_empty());
1009 }
1010
1011 #[test]
1012 fn test_entity_with_properties_stores_props() {
1013 let mut props = HashMap::new();
1014 props.insert("age".into(), Value::Number(42.into()));
1015 let e = Entity::with_properties("e1", "Person", props);
1016 assert!(e.properties.contains_key("age"));
1017 }
1018
1019 #[test]
1022 fn test_graph_add_entity_increments_count() {
1023 let g = make_graph();
1024 add(&g, "a");
1025 assert_eq!(g.entity_count().unwrap(), 1);
1026 }
1027
1028 #[test]
1029 fn test_graph_get_entity_returns_entity() {
1030 let g = make_graph();
1031 g.add_entity(Entity::new("e1", "Person")).unwrap();
1032 let e = g.get_entity(&EntityId::new("e1")).unwrap();
1033 assert_eq!(e.label, "Person");
1034 }
1035
1036 #[test]
1037 fn test_graph_get_entity_missing_returns_error() {
1038 let g = make_graph();
1039 assert!(g.get_entity(&EntityId::new("ghost")).is_err());
1040 }
1041
1042 #[test]
1043 fn test_graph_add_relationship_increments_count() {
1044 let g = make_graph();
1045 add(&g, "a");
1046 add(&g, "b");
1047 link(&g, "a", "b");
1048 assert_eq!(g.relationship_count().unwrap(), 1);
1049 }
1050
1051 #[test]
1052 fn test_graph_add_relationship_missing_source_fails() {
1053 let g = make_graph();
1054 add(&g, "b");
1055 let result = g.add_relationship(Relationship::new("ghost", "b", "X", 1.0));
1056 assert!(result.is_err());
1057 }
1058
1059 #[test]
1060 fn test_graph_add_relationship_missing_target_fails() {
1061 let g = make_graph();
1062 add(&g, "a");
1063 let result = g.add_relationship(Relationship::new("a", "ghost", "X", 1.0));
1064 assert!(result.is_err());
1065 }
1066
1067 #[test]
1068 fn test_graph_remove_entity_removes_relationships() {
1069 let g = make_graph();
1070 add(&g, "a");
1071 add(&g, "b");
1072 link(&g, "a", "b");
1073 g.remove_entity(&EntityId::new("a")).unwrap();
1074 assert_eq!(g.entity_count().unwrap(), 1);
1075 assert_eq!(g.relationship_count().unwrap(), 0);
1076 }
1077
1078 #[test]
1079 fn test_graph_remove_entity_missing_returns_error() {
1080 let g = make_graph();
1081 assert!(g.remove_entity(&EntityId::new("ghost")).is_err());
1082 }
1083
1084 #[test]
1087 fn test_bfs_finds_direct_neighbours() {
1088 let g = make_graph();
1089 add(&g, "a");
1090 add(&g, "b");
1091 add(&g, "c");
1092 link(&g, "a", "b");
1093 link(&g, "a", "c");
1094 let visited = g.bfs(&EntityId::new("a")).unwrap();
1095 assert_eq!(visited.len(), 2);
1096 }
1097
1098 #[test]
1099 fn test_bfs_traverses_chain() {
1100 let g = make_graph();
1101 add(&g, "a");
1102 add(&g, "b");
1103 add(&g, "c");
1104 add(&g, "d");
1105 link(&g, "a", "b");
1106 link(&g, "b", "c");
1107 link(&g, "c", "d");
1108 let visited = g.bfs(&EntityId::new("a")).unwrap();
1109 assert_eq!(visited.len(), 3);
1110 assert_eq!(visited[0], EntityId::new("b"));
1111 }
1112
1113 #[test]
1114 fn test_bfs_handles_isolated_node() {
1115 let g = make_graph();
1116 add(&g, "a");
1117 let visited = g.bfs(&EntityId::new("a")).unwrap();
1118 assert!(visited.is_empty());
1119 }
1120
1121 #[test]
1122 fn test_bfs_missing_start_returns_error() {
1123 let g = make_graph();
1124 assert!(g.bfs(&EntityId::new("ghost")).is_err());
1125 }
1126
1127 #[test]
1130 fn test_dfs_visits_all_reachable_nodes() {
1131 let g = make_graph();
1132 add(&g, "a");
1133 add(&g, "b");
1134 add(&g, "c");
1135 add(&g, "d");
1136 link(&g, "a", "b");
1137 link(&g, "a", "c");
1138 link(&g, "b", "d");
1139 let visited = g.dfs(&EntityId::new("a")).unwrap();
1140 assert_eq!(visited.len(), 3);
1141 }
1142
1143 #[test]
1144 fn test_dfs_handles_isolated_node() {
1145 let g = make_graph();
1146 add(&g, "a");
1147 let visited = g.dfs(&EntityId::new("a")).unwrap();
1148 assert!(visited.is_empty());
1149 }
1150
1151 #[test]
1152 fn test_dfs_missing_start_returns_error() {
1153 let g = make_graph();
1154 assert!(g.dfs(&EntityId::new("ghost")).is_err());
1155 }
1156
1157 #[test]
1160 fn test_shortest_path_direct_connection() {
1161 let g = make_graph();
1162 add(&g, "a");
1163 add(&g, "b");
1164 link(&g, "a", "b");
1165 let path = g
1166 .shortest_path(&EntityId::new("a"), &EntityId::new("b"))
1167 .unwrap();
1168 assert_eq!(path, Some(vec![EntityId::new("a"), EntityId::new("b")]));
1169 }
1170
1171 #[test]
1172 fn test_shortest_path_multi_hop() {
1173 let g = make_graph();
1174 add(&g, "a");
1175 add(&g, "b");
1176 add(&g, "c");
1177 link(&g, "a", "b");
1178 link(&g, "b", "c");
1179 let path = g
1180 .shortest_path(&EntityId::new("a"), &EntityId::new("c"))
1181 .unwrap();
1182 assert_eq!(path.as_ref().map(|p| p.len()), Some(3));
1183 }
1184
1185 #[test]
1186 fn test_shortest_path_returns_none_for_disconnected() {
1187 let g = make_graph();
1188 add(&g, "a");
1189 add(&g, "b");
1190 let path = g
1191 .shortest_path(&EntityId::new("a"), &EntityId::new("b"))
1192 .unwrap();
1193 assert_eq!(path, None);
1194 }
1195
1196 #[test]
1197 fn test_shortest_path_same_node_returns_single_element() {
1198 let g = make_graph();
1199 add(&g, "a");
1200 let path = g
1201 .shortest_path(&EntityId::new("a"), &EntityId::new("a"))
1202 .unwrap();
1203 assert_eq!(path, Some(vec![EntityId::new("a")]));
1204 }
1205
1206 #[test]
1207 fn test_shortest_path_missing_source_returns_error() {
1208 let g = make_graph();
1209 add(&g, "b");
1210 assert!(g
1211 .shortest_path(&EntityId::new("ghost"), &EntityId::new("b"))
1212 .is_err());
1213 }
1214
1215 #[test]
1216 fn test_shortest_path_missing_target_returns_error() {
1217 let g = make_graph();
1218 add(&g, "a");
1219 assert!(g
1220 .shortest_path(&EntityId::new("a"), &EntityId::new("ghost"))
1221 .is_err());
1222 }
1223
1224 #[test]
1227 fn test_transitive_closure_includes_start() {
1228 let g = make_graph();
1229 add(&g, "a");
1230 add(&g, "b");
1231 link(&g, "a", "b");
1232 let closure = g.transitive_closure(&EntityId::new("a")).unwrap();
1233 assert!(closure.contains(&EntityId::new("a")));
1234 assert!(closure.contains(&EntityId::new("b")));
1235 }
1236
1237 #[test]
1238 fn test_transitive_closure_isolated_node_contains_only_self() {
1239 let g = make_graph();
1240 add(&g, "a");
1241 let closure = g.transitive_closure(&EntityId::new("a")).unwrap();
1242 assert_eq!(closure.len(), 1);
1243 }
1244
1245 #[test]
1248 fn test_mem_graph_error_converts_to_runtime_error() {
1249 let e = MemGraphError::EntityNotFound("x".into());
1250 let re: AgentRuntimeError = e.into();
1251 assert!(matches!(re, AgentRuntimeError::Graph(_)));
1252 }
1253
1254 #[test]
1257 fn test_shortest_path_weighted_simple() {
1258 let g = make_graph();
1261 add(&g, "a");
1262 add(&g, "b");
1263 add(&g, "c");
1264 link_w(&g, "a", "b", 1.0);
1265 link_w(&g, "b", "c", 2.0);
1266 g.add_relationship(Relationship::new("a", "c", "DIRECT", 10.0))
1267 .unwrap();
1268
1269 let result = g
1270 .shortest_path_weighted(&EntityId::new("a"), &EntityId::new("c"))
1271 .unwrap();
1272 assert!(result.is_some());
1273 let (path, weight) = result.unwrap();
1274 assert_eq!(
1276 path,
1277 vec![EntityId::new("a"), EntityId::new("b"), EntityId::new("c")]
1278 );
1279 assert!((weight - 3.0).abs() < 1e-5);
1280 }
1281
1282 #[test]
1283 fn test_shortest_path_weighted_returns_none_for_disconnected() {
1284 let g = make_graph();
1285 add(&g, "a");
1286 add(&g, "b");
1287 let result = g
1288 .shortest_path_weighted(&EntityId::new("a"), &EntityId::new("b"))
1289 .unwrap();
1290 assert!(result.is_none());
1291 }
1292
1293 #[test]
1294 fn test_shortest_path_weighted_same_node() {
1295 let g = make_graph();
1296 add(&g, "a");
1297 let result = g
1298 .shortest_path_weighted(&EntityId::new("a"), &EntityId::new("a"))
1299 .unwrap();
1300 assert_eq!(result, Some((vec![EntityId::new("a")], 0.0)));
1301 }
1302
1303 #[test]
1304 fn test_shortest_path_weighted_negative_weight_errors() {
1305 let g = make_graph();
1306 add(&g, "a");
1307 add(&g, "b");
1308 g.add_relationship(Relationship::new("a", "b", "NEG", -1.0))
1309 .unwrap();
1310 let result = g.shortest_path_weighted(&EntityId::new("a"), &EntityId::new("b"));
1311 assert!(result.is_err());
1312 }
1313
1314 #[test]
1317 fn test_degree_centrality_basic() {
1318 let g = make_graph();
1322 add(&g, "a");
1323 add(&g, "b");
1324 add(&g, "c");
1325 add(&g, "d");
1326 link(&g, "a", "b");
1327 link(&g, "a", "c");
1328 link(&g, "a", "d");
1329
1330 let centrality = g.degree_centrality().unwrap();
1331 let a_cent = *centrality.get(&EntityId::new("a")).unwrap();
1332 let b_cent = *centrality.get(&EntityId::new("b")).unwrap();
1333
1334 assert!((a_cent - 1.0).abs() < 1e-5, "a centrality was {a_cent}");
1335 assert!(
1336 (b_cent - 1.0 / 3.0).abs() < 1e-5,
1337 "b centrality was {b_cent}"
1338 );
1339 }
1340
1341 #[test]
1344 fn test_betweenness_centrality_chain() {
1345 let g = make_graph();
1349 add(&g, "a");
1350 add(&g, "b");
1351 add(&g, "c");
1352 add(&g, "d");
1353 link(&g, "a", "b");
1354 link(&g, "b", "c");
1355 link(&g, "c", "d");
1356
1357 let centrality = g.betweenness_centrality().unwrap();
1358 let a_cent = *centrality.get(&EntityId::new("a")).unwrap();
1359 let b_cent = *centrality.get(&EntityId::new("b")).unwrap();
1360 let c_cent = *centrality.get(&EntityId::new("c")).unwrap();
1361 let d_cent = *centrality.get(&EntityId::new("d")).unwrap();
1362
1363 assert!((a_cent).abs() < 1e-5, "expected a_cent ~ 0, got {a_cent}");
1365 assert!(b_cent > 0.0, "expected b_cent > 0, got {b_cent}");
1366 assert!(c_cent > 0.0, "expected c_cent > 0, got {c_cent}");
1367 assert!((d_cent).abs() < 1e-5, "expected d_cent ~ 0, got {d_cent}");
1368 }
1369
1370 #[test]
1373 fn test_label_propagation_communities_two_clusters() {
1374 let g = make_graph();
1378 for id in &["a", "b", "c", "x", "y", "z"] {
1379 add(&g, id);
1380 }
1381 link(&g, "a", "b");
1383 link(&g, "b", "a");
1384 link(&g, "b", "c");
1385 link(&g, "c", "b");
1386 link(&g, "a", "c");
1387 link(&g, "c", "a");
1388 link(&g, "x", "y");
1390 link(&g, "y", "x");
1391 link(&g, "y", "z");
1392 link(&g, "z", "y");
1393 link(&g, "x", "z");
1394 link(&g, "z", "x");
1395
1396 let communities = g.label_propagation_communities(100).unwrap();
1397
1398 let label_a = communities[&EntityId::new("a")];
1399 let label_b = communities[&EntityId::new("b")];
1400 let label_c = communities[&EntityId::new("c")];
1401 let label_x = communities[&EntityId::new("x")];
1402 let label_y = communities[&EntityId::new("y")];
1403 let label_z = communities[&EntityId::new("z")];
1404
1405 assert_eq!(label_a, label_b, "a and b should be in same community");
1408 assert_eq!(label_b, label_c, "b and c should be in same community");
1409 assert_eq!(label_x, label_y, "x and y should be in same community");
1410 assert_eq!(label_y, label_z, "y and z should be in same community");
1411 assert_ne!(
1412 label_a, label_x,
1413 "cluster 1 and cluster 2 should be different communities"
1414 );
1415 }
1416
1417 #[test]
1420 fn test_subgraph_extracts_correct_nodes_and_edges() {
1421 let g = make_graph();
1424 add(&g, "a");
1425 add(&g, "b");
1426 add(&g, "c");
1427 add(&g, "d");
1428 link(&g, "a", "b");
1429 link(&g, "b", "c");
1430 link(&g, "c", "d");
1431
1432 let sub = g
1433 .subgraph(&[EntityId::new("a"), EntityId::new("b"), EntityId::new("c")])
1434 .unwrap();
1435
1436 assert_eq!(sub.entity_count().unwrap(), 3);
1437 assert_eq!(sub.relationship_count().unwrap(), 2);
1438
1439 assert!(sub.get_entity(&EntityId::new("d")).is_err());
1441
1442 let path = sub
1444 .shortest_path(&EntityId::new("a"), &EntityId::new("c"))
1445 .unwrap();
1446 assert!(path.is_some());
1447 assert_eq!(path.unwrap().len(), 3);
1448 }
1449
1450 #[test]
1453 fn test_detect_cycles_dag_returns_false() {
1454 let g = make_graph();
1455 add(&g, "a");
1456 add(&g, "b");
1457 add(&g, "c");
1458 link(&g, "a", "b");
1459 link(&g, "b", "c");
1460 assert_eq!(g.detect_cycles().unwrap(), false);
1461 }
1462
1463 #[test]
1464 fn test_detect_cycles_self_loop_returns_true() {
1465 let g = make_graph();
1466 add(&g, "a");
1467 g.add_relationship(Relationship::new("a", "a", "SELF", 1.0))
1469 .unwrap();
1470 assert_eq!(g.detect_cycles().unwrap(), true);
1471 }
1472
1473 #[test]
1474 fn test_detect_cycles_simple_cycle_returns_true() {
1475 let g = make_graph();
1476 add(&g, "a");
1477 add(&g, "b");
1478 link(&g, "a", "b");
1479 g.add_relationship(Relationship::new("b", "a", "BACK", 1.0))
1480 .unwrap();
1481 assert_eq!(g.detect_cycles().unwrap(), true);
1482 }
1483
1484 #[test]
1485 fn test_detect_cycles_empty_graph_returns_false() {
1486 let g = make_graph();
1487 assert_eq!(g.detect_cycles().unwrap(), false);
1488 }
1489
1490 #[test]
1491 fn test_detect_cycles_result_is_cached() {
1492 let g = make_graph();
1493 add(&g, "x");
1494 add(&g, "y");
1495 link(&g, "x", "y");
1496 let r1 = g.detect_cycles().unwrap();
1498 let r2 = g.detect_cycles().unwrap();
1500 assert_eq!(r1, r2);
1501 }
1502
1503 #[test]
1504 fn test_detect_cycles_cache_invalidated_on_mutation() {
1505 let g = make_graph();
1506 add(&g, "a");
1507 add(&g, "b");
1508 link(&g, "a", "b");
1509 assert_eq!(g.detect_cycles().unwrap(), false);
1510
1511 g.add_relationship(Relationship::new("b", "a", "BACK", 1.0))
1513 .unwrap();
1514 assert_eq!(
1515 g.detect_cycles().unwrap(),
1516 true,
1517 "cache should be invalidated after adding a back edge"
1518 );
1519 }
1520
1521 #[test]
1524 fn test_bfs_bounded_respects_max_depth() {
1525 let g = make_graph();
1527 add(&g, "a");
1528 add(&g, "b");
1529 add(&g, "c");
1530 add(&g, "d");
1531 link(&g, "a", "b");
1532 link(&g, "b", "c");
1533 link(&g, "c", "d");
1534
1535 let visited = g.bfs_bounded("a", 1, 100).unwrap();
1537 assert!(visited.contains(&EntityId::new("a")));
1538 assert!(visited.contains(&EntityId::new("b")));
1539 assert!(!visited.contains(&EntityId::new("c")), "c is at depth 2, should not be visited");
1540 }
1541
1542 #[test]
1545 fn test_path_exists_returns_true() {
1546 let g = make_graph();
1547 add(&g, "a");
1548 add(&g, "b");
1549 add(&g, "c");
1550 link(&g, "a", "b");
1551 link(&g, "b", "c");
1552 assert_eq!(g.path_exists("a", "c").unwrap(), true);
1553 }
1554
1555 #[test]
1556 fn test_path_exists_returns_false() {
1557 let g = make_graph();
1558 add(&g, "a");
1559 add(&g, "b");
1560 assert_eq!(g.path_exists("a", "b").unwrap(), false);
1561 }
1562
1563 #[test]
1566 fn test_entity_id_as_str() {
1567 let id = EntityId::new("my-entity");
1568 assert_eq!(id.as_str(), "my-entity");
1569 }
1570
1571 #[test]
1574 fn test_entity_id_as_ref_str() {
1575 let id = EntityId::new("asref-test");
1576 let s: &str = id.as_ref();
1577 assert_eq!(s, "asref-test");
1578 }
1579
1580 #[test]
1581 fn test_dfs_bounded_respects_max_nodes() {
1582 let g = make_graph();
1584 add(&g, "a");
1585 add(&g, "b");
1586 add(&g, "c");
1587 add(&g, "d");
1588 link(&g, "a", "b");
1589 link(&g, "b", "c");
1590 link(&g, "c", "d");
1591
1592 let visited = g.dfs_bounded("a", 100, 2).unwrap();
1594 assert_eq!(visited.len(), 2, "should stop at 2 nodes");
1595 }
1596}