1use crate::error::AgentRuntimeError;
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
20use std::sync::{Arc, Mutex};
21
22fn recover_lock<'a, T>(
28 result: std::sync::LockResult<std::sync::MutexGuard<'a, T>>,
29 ctx: &str,
30) -> std::sync::MutexGuard<'a, T>
31where
32 T: ?Sized,
33{
34 match result {
35 Ok(guard) => guard,
36 Err(poisoned) => {
37 tracing::warn!("mutex poisoned in {ctx}, recovering inner value");
38 poisoned.into_inner()
39 }
40 }
41}
42
/// `f32` wrapper with a lawful total order, so costs can be stored in a
/// `BinaryHeap`.
///
/// Non-NaN values compare exactly as `f32` does (in particular
/// `0.0 == -0.0`). NaN compares equal to every other NaN and greater than
/// every non-NaN value. This keeps `Eq`/`Ord` reflexive, antisymmetric and
/// transitive even if a NaN ever reaches a heap, which the standard
/// collections require of `Ord` implementations.
#[derive(Debug, Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    // Defined via `cmp` so equality always agrees with the ordering
    // (a derived impl would make NaN != NaN while still claiming `Eq`).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `partial_cmp` is `None` only when at least one side is NaN. Ordering
        // the `is_nan` flags then puts NaN above all numbers and makes two
        // NaNs equal, instead of reporting `Greater` in both directions.
        self.0
            .partial_cmp(&other.0)
            .unwrap_or_else(|| self.0.is_nan().cmp(&other.0.is_nan()))
    }
}
65
/// Identifier of an [`Entity`] in a [`GraphStore`].
///
/// A thin newtype over `String`. It is the key of the store's entity map and
/// the endpoint type of every [`Relationship`]; the derived `Ord` is what
/// orders equal-cost entries in the weighted shortest-path heap.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct EntityId(pub String);
71
72impl EntityId {
73 pub fn new(id: impl Into<String>) -> Self {
75 Self(id.into())
76 }
77}
78
79impl std::fmt::Display for EntityId {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 write!(f, "{}", self.0)
82 }
83}
84
/// A node in the graph: an id, a label, and free-form JSON properties.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Entity {
    /// Unique identifier; [`GraphStore`] keys its entity map by this value.
    pub id: EntityId,
    /// Free-text label (the tests use values such as `"Person"` and `"Node"`).
    pub label: String,
    /// Arbitrary key/value properties stored as `serde_json::Value`s.
    pub properties: HashMap<String, Value>,
}
97
98impl Entity {
99 pub fn new(id: impl Into<String>, label: impl Into<String>) -> Self {
101 Self {
102 id: EntityId::new(id),
103 label: label.into(),
104 properties: HashMap::new(),
105 }
106 }
107
108 pub fn with_properties(
110 id: impl Into<String>,
111 label: impl Into<String>,
112 properties: HashMap<String, Value>,
113 ) -> Self {
114 Self {
115 id: EntityId::new(id),
116 label: label.into(),
117 properties,
118 }
119 }
120}
121
/// A directed, typed, weighted edge between two entities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Relationship {
    /// Source entity; graph traversals follow edges from `from` to `to`.
    pub from: EntityId,
    /// Target entity.
    pub to: EntityId,
    /// Relationship type. [`GraphStore::add_relationship`] rejects a second
    /// edge with the same `(from, to, kind)`, so parallel edges must differ
    /// in kind.
    pub kind: String,
    /// Edge cost used by [`GraphStore::shortest_path_weighted`], which
    /// rejects graphs containing negative weights.
    pub weight: f32,
}
136
137impl Relationship {
138 pub fn new(
140 from: impl Into<String>,
141 to: impl Into<String>,
142 kind: impl Into<String>,
143 weight: f32,
144 ) -> Self {
145 Self {
146 from: EntityId::new(from),
147 to: EntityId::new(to),
148 kind: kind.into(),
149 weight,
150 }
151 }
152}
153
/// Graph-specific error conditions.
///
/// Converted into [`AgentRuntimeError::Graph`] (keeping only the display
/// text) by the `From` impl below.
#[derive(Debug, thiserror::Error)]
pub enum MemGraphError {
    /// No entity with the given id exists.
    #[error("Entity '{0}' not found")]
    EntityNotFound(String),

    /// An edge with the same source, target and kind is already stored.
    #[error("Relationship '{kind}' from '{from}' to '{to}' already exists")]
    DuplicateRelationship {
        /// Source entity id.
        from: String,
        /// Target entity id.
        to: String,
        /// Relationship kind.
        kind: String,
    },

    /// Catch-all for unexpected internal failures.
    #[error("Graph internal error: {0}")]
    Internal(String),
}
178
179impl From<MemGraphError> for AgentRuntimeError {
180 fn from(e: MemGraphError) -> Self {
181 AgentRuntimeError::Graph(e.to_string())
182 }
183}
184
/// In-memory directed graph of [`Entity`] nodes and [`Relationship`] edges.
///
/// All state lives behind a single `Arc<Mutex<..>>`. Every operation locks
/// that mutex, and cloning a `GraphStore` clones the `Arc`, so clones share
/// (and mutate) the same underlying graph.
#[derive(Debug, Clone)]
pub struct GraphStore {
    inner: Arc<Mutex<GraphInner>>,
}
199
/// Mutex-protected state of a [`GraphStore`].
#[derive(Debug)]
struct GraphInner {
    /// Entities keyed by their id.
    entities: HashMap<EntityId, Entity>,
    /// Edges in insertion order. `GraphStore` methods keep every edge's
    /// endpoints present in `entities` (checked on insert, pruned on removal).
    relationships: Vec<Relationship>,
}
205
206impl GraphStore {
207 pub fn new() -> Self {
209 Self {
210 inner: Arc::new(Mutex::new(GraphInner {
211 entities: HashMap::new(),
212 relationships: Vec::new(),
213 })),
214 }
215 }
216
217 pub fn add_entity(&self, entity: Entity) -> Result<(), AgentRuntimeError> {
221 let mut inner = recover_lock(self.inner.lock(), "add_entity");
222 inner.entities.insert(entity.id.clone(), entity);
223 Ok(())
224 }
225
226 pub fn get_entity(&self, id: &EntityId) -> Result<Entity, AgentRuntimeError> {
228 let inner = recover_lock(self.inner.lock(), "get_entity");
229 inner
230 .entities
231 .get(id)
232 .cloned()
233 .ok_or_else(|| AgentRuntimeError::Graph(format!("entity '{}' not found", id.0)))
234 }
235
236 pub fn add_relationship(&self, rel: Relationship) -> Result<(), AgentRuntimeError> {
240 let mut inner = recover_lock(self.inner.lock(), "add_relationship");
241
242 if !inner.entities.contains_key(&rel.from) {
243 return Err(AgentRuntimeError::Graph(format!(
244 "source entity '{}' not found",
245 rel.from.0
246 )));
247 }
248 if !inner.entities.contains_key(&rel.to) {
249 return Err(AgentRuntimeError::Graph(format!(
250 "target entity '{}' not found",
251 rel.to.0
252 )));
253 }
254
255 let duplicate = inner
259 .relationships
260 .iter()
261 .any(|r| r.from == rel.from && r.to == rel.to && r.kind == rel.kind);
262 if duplicate {
263 return Err(AgentRuntimeError::Graph(
264 MemGraphError::DuplicateRelationship {
265 from: rel.from.0.clone(),
266 to: rel.to.0.clone(),
267 kind: rel.kind.clone(),
268 }
269 .to_string(),
270 ));
271 }
272
273 inner.relationships.push(rel);
274 Ok(())
275 }
276
277 pub fn remove_entity(&self, id: &EntityId) -> Result<(), AgentRuntimeError> {
279 let mut inner = recover_lock(self.inner.lock(), "remove_entity");
280
281 if inner.entities.remove(id).is_none() {
282 return Err(AgentRuntimeError::Graph(format!(
283 "entity '{}' not found",
284 id.0
285 )));
286 }
287 inner.relationships.retain(|r| &r.from != id && &r.to != id);
288 Ok(())
289 }
290
291 fn neighbours(relationships: &[Relationship], id: &EntityId) -> Vec<EntityId> {
293 relationships
294 .iter()
295 .filter(|r| &r.from == id)
296 .map(|r| r.to.clone())
297 .collect()
298 }
299
300 #[tracing::instrument(skip(self))]
304 pub fn bfs(&self, start: &EntityId) -> Result<Vec<EntityId>, AgentRuntimeError> {
305 let inner = recover_lock(self.inner.lock(), "bfs");
306
307 if !inner.entities.contains_key(start) {
308 return Err(AgentRuntimeError::Graph(format!(
309 "start entity '{}' not found",
310 start.0
311 )));
312 }
313
314 let mut visited: HashSet<EntityId> = HashSet::new();
315 let mut queue: VecDeque<EntityId> = VecDeque::new();
316 let mut result: Vec<EntityId> = Vec::new();
317
318 visited.insert(start.clone());
319 queue.push_back(start.clone());
320
321 while let Some(current) = queue.pop_front() {
322 let neighbours: Vec<EntityId> = Self::neighbours(&inner.relationships, ¤t);
323 for neighbour in neighbours {
324 if visited.insert(neighbour.clone()) {
325 result.push(neighbour.clone());
326 queue.push_back(neighbour);
327 }
328 }
329 }
330
331 tracing::debug!("BFS visited {} nodes", result.len());
332 Ok(result)
333 }
334
335 #[tracing::instrument(skip(self))]
339 pub fn dfs(&self, start: &EntityId) -> Result<Vec<EntityId>, AgentRuntimeError> {
340 let inner = recover_lock(self.inner.lock(), "dfs");
341
342 if !inner.entities.contains_key(start) {
343 return Err(AgentRuntimeError::Graph(format!(
344 "start entity '{}' not found",
345 start.0
346 )));
347 }
348
349 let mut visited: HashSet<EntityId> = HashSet::new();
350 let mut stack: Vec<EntityId> = Vec::new();
351 let mut result: Vec<EntityId> = Vec::new();
352
353 visited.insert(start.clone());
354 stack.push(start.clone());
355
356 while let Some(current) = stack.pop() {
357 let neighbours: Vec<EntityId> = Self::neighbours(&inner.relationships, ¤t);
358 for neighbour in neighbours {
359 if visited.insert(neighbour.clone()) {
360 result.push(neighbour.clone());
361 stack.push(neighbour);
362 }
363 }
364 }
365
366 tracing::debug!("DFS visited {} nodes", result.len());
367 Ok(result)
368 }
369
370 #[tracing::instrument(skip(self))]
376 pub fn shortest_path(
377 &self,
378 from: &EntityId,
379 to: &EntityId,
380 ) -> Result<Option<Vec<EntityId>>, AgentRuntimeError> {
381 let inner = recover_lock(self.inner.lock(), "shortest_path");
382
383 if !inner.entities.contains_key(from) {
384 return Err(AgentRuntimeError::Graph(format!(
385 "source entity '{}' not found",
386 from.0
387 )));
388 }
389 if !inner.entities.contains_key(to) {
390 return Err(AgentRuntimeError::Graph(format!(
391 "target entity '{}' not found",
392 to.0
393 )));
394 }
395
396 if from == to {
397 return Ok(Some(vec![from.clone()]));
398 }
399
400 let mut visited: HashSet<EntityId> = HashSet::new();
401 let mut queue: VecDeque<Vec<EntityId>> = VecDeque::new();
402
403 visited.insert(from.clone());
404 queue.push_back(vec![from.clone()]);
405
406 while let Some(path) = queue.pop_front() {
407 let current = match path.last() {
408 Some(c) => c.clone(),
409 None => continue,
410 };
411
412 let neighbours: Vec<EntityId> = Self::neighbours(&inner.relationships, ¤t);
413
414 for neighbour in neighbours {
415 if &neighbour == to {
416 let mut full_path = path.clone();
417 full_path.push(neighbour);
418 return Ok(Some(full_path));
419 }
420 if visited.insert(neighbour.clone()) {
421 let mut new_path = path.clone();
422 new_path.push(neighbour.clone());
423 queue.push_back(new_path);
424 }
425 }
426 }
427
428 Ok(None)
429 }
430
431 pub fn shortest_path_weighted(
441 &self,
442 from: &EntityId,
443 to: &EntityId,
444 ) -> Result<Option<(Vec<EntityId>, f32)>, AgentRuntimeError> {
445 let inner = recover_lock(self.inner.lock(), "shortest_path_weighted");
446
447 if !inner.entities.contains_key(from) {
448 return Err(AgentRuntimeError::Graph(format!(
449 "source entity '{}' not found",
450 from.0
451 )));
452 }
453 if !inner.entities.contains_key(to) {
454 return Err(AgentRuntimeError::Graph(format!(
455 "target entity '{}' not found",
456 to.0
457 )));
458 }
459
460 for rel in &inner.relationships {
462 if rel.weight < 0.0 {
463 return Err(AgentRuntimeError::Graph(format!(
464 "negative weight {:.4} on edge '{}' -> '{}'",
465 rel.weight, rel.from.0, rel.to.0
466 )));
467 }
468 }
469
470 if from == to {
471 return Ok(Some((vec![from.clone()], 0.0)));
472 }
473
474 let mut dist: HashMap<EntityId, f32> = HashMap::new();
477 let mut prev: HashMap<EntityId, EntityId> = HashMap::new();
478 let mut heap: BinaryHeap<(OrdF32, EntityId)> = BinaryHeap::new();
480
481 dist.insert(from.clone(), 0.0);
482 heap.push((OrdF32(-0.0), from.clone()));
483
484 while let Some((OrdF32(neg_cost), current)) = heap.pop() {
485 let cost = -neg_cost;
486
487 if let Some(&best) = dist.get(¤t) {
489 if cost > best {
490 continue;
491 }
492 }
493
494 if ¤t == to {
495 let mut path = vec![to.clone()];
497 let mut node = to.clone();
498 while let Some(p) = prev.get(&node) {
499 path.push(p.clone());
500 node = p.clone();
501 }
502 path.reverse();
503 return Ok(Some((path, cost)));
504 }
505
506 for rel in inner.relationships.iter().filter(|r| &r.from == ¤t) {
507 let next_cost = cost + rel.weight;
508 let entry = dist.entry(rel.to.clone()).or_insert(f32::INFINITY);
509 if next_cost < *entry {
510 *entry = next_cost;
511 prev.insert(rel.to.clone(), current.clone());
512 heap.push((OrdF32(-next_cost), rel.to.clone()));
513 }
514 }
515 }
516
517 Ok(None)
518 }
519
520 pub fn transitive_closure(
522 &self,
523 start: &EntityId,
524 ) -> Result<HashSet<EntityId>, AgentRuntimeError> {
525 let reachable = self.bfs(start)?;
526 let mut set: HashSet<EntityId> = reachable.into_iter().collect();
527 set.insert(start.clone());
528 Ok(set)
529 }
530
531 pub fn entity_count(&self) -> Result<usize, AgentRuntimeError> {
533 let inner = recover_lock(self.inner.lock(), "entity_count");
534 Ok(inner.entities.len())
535 }
536
537 pub fn relationship_count(&self) -> Result<usize, AgentRuntimeError> {
539 let inner = recover_lock(self.inner.lock(), "relationship_count");
540 Ok(inner.relationships.len())
541 }
542
543 pub fn degree_centrality(&self) -> Result<HashMap<EntityId, f32>, AgentRuntimeError> {
546 let inner = recover_lock(self.inner.lock(), "degree_centrality");
547 let n = inner.entities.len();
548
549 let mut out_degree: HashMap<EntityId, usize> = HashMap::new();
550 let mut in_degree: HashMap<EntityId, usize> = HashMap::new();
551
552 for id in inner.entities.keys() {
553 out_degree.insert(id.clone(), 0);
554 in_degree.insert(id.clone(), 0);
555 }
556
557 for rel in &inner.relationships {
558 *out_degree.entry(rel.from.clone()).or_insert(0) += 1;
559 *in_degree.entry(rel.to.clone()).or_insert(0) += 1;
560 }
561
562 let denom = if n <= 1 { 1.0 } else { (n - 1) as f32 };
563 let mut result = HashMap::new();
564 for id in inner.entities.keys() {
565 let od = *out_degree.get(id).unwrap_or(&0);
566 let id_ = *in_degree.get(id).unwrap_or(&0);
567 let centrality = if n <= 1 {
568 0.0
569 } else {
570 (od + id_) as f32 / denom
571 };
572 result.insert(id.clone(), centrality);
573 }
574
575 Ok(result)
576 }
577
578 pub fn betweenness_centrality(&self) -> Result<HashMap<EntityId, f32>, AgentRuntimeError> {
584 let inner = recover_lock(self.inner.lock(), "betweenness_centrality");
585 let n = inner.entities.len();
586 let nodes: Vec<EntityId> = inner.entities.keys().cloned().collect();
587
588 let mut centrality: HashMap<EntityId, f32> =
589 nodes.iter().map(|id| (id.clone(), 0.0f32)).collect();
590
591 for source in &nodes {
592 let mut stack: Vec<EntityId> = Vec::new();
594 let mut predecessors: HashMap<EntityId, Vec<EntityId>> =
595 nodes.iter().map(|id| (id.clone(), vec![])).collect();
596 let mut sigma: HashMap<EntityId, f32> =
597 nodes.iter().map(|id| (id.clone(), 0.0f32)).collect();
598 let mut dist: HashMap<EntityId, i64> =
599 nodes.iter().map(|id| (id.clone(), -1i64)).collect();
600
601 *sigma.entry(source.clone()).or_insert(0.0) = 1.0;
602 *dist.entry(source.clone()).or_insert(-1) = 0;
603
604 let mut queue: VecDeque<EntityId> = VecDeque::new();
605 queue.push_back(source.clone());
606
607 while let Some(v) = queue.pop_front() {
608 stack.push(v.clone());
609 let d_v = *dist.get(&v).unwrap_or(&0);
610 let sigma_v = *sigma.get(&v).unwrap_or(&0.0);
611 for rel in inner.relationships.iter().filter(|r| &r.from == &v) {
612 let w = &rel.to;
613 let d_w = dist.get(w).copied().unwrap_or(-1);
614 if d_w < 0 {
615 queue.push_back(w.clone());
616 *dist.entry(w.clone()).or_insert(-1) = d_v + 1;
617 }
618 if dist.get(w).copied().unwrap_or(-1) == d_v + 1 {
619 *sigma.entry(w.clone()).or_insert(0.0) += sigma_v;
620 predecessors.entry(w.clone()).or_default().push(v.clone());
621 }
622 }
623 }
624
625 let mut delta: HashMap<EntityId, f32> =
626 nodes.iter().map(|id| (id.clone(), 0.0f32)).collect();
627
628 while let Some(w) = stack.pop() {
629 let delta_w = *delta.get(&w).unwrap_or(&0.0);
630 let sigma_w = *sigma.get(&w).unwrap_or(&1.0);
631 for v in predecessors.get(&w).cloned().unwrap_or_default() {
632 let sigma_v = *sigma.get(&v).unwrap_or(&1.0);
633 let contribution = (sigma_v / sigma_w) * (1.0 + delta_w);
634 *delta.entry(v.clone()).or_insert(0.0) += contribution;
635 }
636 if &w != source {
637 *centrality.entry(w.clone()).or_insert(0.0) += delta_w;
638 }
639 }
640 }
641
642 if n > 2 {
644 let norm = 2.0 / (((n - 1) * (n - 2)) as f32);
645 for v in centrality.values_mut() {
646 *v *= norm;
647 }
648 } else {
649 for v in centrality.values_mut() {
650 *v = 0.0;
651 }
652 }
653
654 Ok(centrality)
655 }
656
657 pub fn label_propagation_communities(
662 &self,
663 max_iterations: usize,
664 ) -> Result<HashMap<EntityId, usize>, AgentRuntimeError> {
665 let inner = recover_lock(self.inner.lock(), "label_propagation_communities");
666 let nodes: Vec<EntityId> = inner.entities.keys().cloned().collect();
667
668 let mut labels: HashMap<EntityId, usize> = nodes
670 .iter()
671 .enumerate()
672 .map(|(i, id)| (id.clone(), i))
673 .collect();
674
675 for _ in 0..max_iterations {
676 let mut changed = false;
677 for node in &nodes {
679 let neighbour_labels: Vec<usize> = inner
681 .relationships
682 .iter()
683 .filter(|r| &r.from == node || &r.to == node)
684 .map(|r| {
685 if &r.from == node {
686 labels.get(&r.to).copied().unwrap_or(0)
687 } else {
688 labels.get(&r.from).copied().unwrap_or(0)
689 }
690 })
691 .collect();
692
693 if neighbour_labels.is_empty() {
694 continue;
695 }
696
697 let mut freq: HashMap<usize, usize> = HashMap::new();
699 for &lbl in &neighbour_labels {
700 *freq.entry(lbl).or_insert(0) += 1;
701 }
702 let best = freq
703 .into_iter()
704 .max_by_key(|&(_, count)| count)
705 .map(|(lbl, _)| lbl);
706
707 if let Some(new_label) = best {
708 let current = labels.entry(node.clone()).or_insert(0);
709 if *current != new_label {
710 *current = new_label;
711 changed = true;
712 }
713 }
714 }
715
716 if !changed {
717 break;
718 }
719 }
720
721 Ok(labels)
722 }
723
724 pub fn subgraph(&self, node_ids: &[EntityId]) -> Result<GraphStore, AgentRuntimeError> {
727 let inner = recover_lock(self.inner.lock(), "subgraph");
728 let id_set: HashSet<&EntityId> = node_ids.iter().collect();
729
730 let new_store = GraphStore::new();
731
732 for id in node_ids {
733 let entity = inner
734 .entities
735 .get(id)
736 .ok_or_else(|| AgentRuntimeError::Graph(format!("entity '{}' not found", id.0)))?
737 .clone();
738 let mut new_inner = recover_lock(new_store.inner.lock(), "subgraph:add_entity");
740 new_inner.entities.insert(entity.id.clone(), entity);
741 }
742
743 for rel in inner.relationships.iter() {
744 if id_set.contains(&rel.from) && id_set.contains(&rel.to) {
745 let mut new_inner =
746 recover_lock(new_store.inner.lock(), "subgraph:add_relationship");
747 new_inner.relationships.push(rel.clone());
748 }
749 }
750
751 Ok(new_store)
752 }
753}
754
755impl Default for GraphStore {
756 fn default() -> Self {
757 Self::new()
758 }
759}
760
#[cfg(test)]
mod tests {
    use super::*;

    // ── Helpers ─────────────────────────────────────────────────────────

    /// Fresh, empty store for each test.
    fn make_graph() -> GraphStore {
        GraphStore::new()
    }

    /// Adds an entity `id` with label `"Node"`; panics if insertion fails.
    fn add(g: &GraphStore, id: &str) {
        g.add_entity(Entity::new(id, "Node")).unwrap();
    }

    /// Adds a `CONNECTS` edge of weight 1.0; both endpoints must already exist.
    fn link(g: &GraphStore, from: &str, to: &str) {
        g.add_relationship(Relationship::new(from, to, "CONNECTS", 1.0))
            .unwrap();
    }

    /// Like `link`, but with an explicit weight.
    fn link_w(g: &GraphStore, from: &str, to: &str, weight: f32) {
        g.add_relationship(Relationship::new(from, to, "CONNECTS", weight))
            .unwrap();
    }

    // ── EntityId ────────────────────────────────────────────────────────

    #[test]
    fn test_entity_id_equality() {
        assert_eq!(EntityId::new("a"), EntityId::new("a"));
        assert_ne!(EntityId::new("a"), EntityId::new("b"));
    }

    #[test]
    fn test_entity_id_display() {
        let id = EntityId::new("hello");
        assert_eq!(id.to_string(), "hello");
    }

    // ── Entity ──────────────────────────────────────────────────────────

    #[test]
    fn test_entity_new_has_empty_properties() {
        let e = Entity::new("e1", "Person");
        assert!(e.properties.is_empty());
    }

    #[test]
    fn test_entity_with_properties_stores_props() {
        let mut props = HashMap::new();
        props.insert("age".into(), Value::Number(42.into()));
        let e = Entity::with_properties("e1", "Person", props);
        assert!(e.properties.contains_key("age"));
    }

    // ── GraphStore: entity / relationship CRUD ──────────────────────────

    #[test]
    fn test_graph_add_entity_increments_count() {
        let g = make_graph();
        add(&g, "a");
        assert_eq!(g.entity_count().unwrap(), 1);
    }

    #[test]
    fn test_graph_get_entity_returns_entity() {
        let g = make_graph();
        g.add_entity(Entity::new("e1", "Person")).unwrap();
        let e = g.get_entity(&EntityId::new("e1")).unwrap();
        assert_eq!(e.label, "Person");
    }

    #[test]
    fn test_graph_get_entity_missing_returns_error() {
        let g = make_graph();
        assert!(g.get_entity(&EntityId::new("ghost")).is_err());
    }

    #[test]
    fn test_graph_add_relationship_increments_count() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        link(&g, "a", "b");
        assert_eq!(g.relationship_count().unwrap(), 1);
    }

    #[test]
    fn test_graph_add_relationship_missing_source_fails() {
        let g = make_graph();
        add(&g, "b");
        let result = g.add_relationship(Relationship::new("ghost", "b", "X", 1.0));
        assert!(result.is_err());
    }

    #[test]
    fn test_graph_add_relationship_missing_target_fails() {
        let g = make_graph();
        add(&g, "a");
        let result = g.add_relationship(Relationship::new("a", "ghost", "X", 1.0));
        assert!(result.is_err());
    }

    // Removing an entity must also drop the edges that reference it.
    #[test]
    fn test_graph_remove_entity_removes_relationships() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        link(&g, "a", "b");
        g.remove_entity(&EntityId::new("a")).unwrap();
        assert_eq!(g.entity_count().unwrap(), 1);
        assert_eq!(g.relationship_count().unwrap(), 0);
    }

    #[test]
    fn test_graph_remove_entity_missing_returns_error() {
        let g = make_graph();
        assert!(g.remove_entity(&EntityId::new("ghost")).is_err());
    }

    // ── BFS (result excludes the start node) ────────────────────────────

    #[test]
    fn test_bfs_finds_direct_neighbours() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        link(&g, "a", "b");
        link(&g, "a", "c");
        let visited = g.bfs(&EntityId::new("a")).unwrap();
        assert_eq!(visited.len(), 2);
    }

    // Also pins the order: the first hop from `a` comes first.
    #[test]
    fn test_bfs_traverses_chain() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        add(&g, "d");
        link(&g, "a", "b");
        link(&g, "b", "c");
        link(&g, "c", "d");
        let visited = g.bfs(&EntityId::new("a")).unwrap();
        assert_eq!(visited.len(), 3);
        assert_eq!(visited[0], EntityId::new("b"));
    }

    #[test]
    fn test_bfs_handles_isolated_node() {
        let g = make_graph();
        add(&g, "a");
        let visited = g.bfs(&EntityId::new("a")).unwrap();
        assert!(visited.is_empty());
    }

    #[test]
    fn test_bfs_missing_start_returns_error() {
        let g = make_graph();
        assert!(g.bfs(&EntityId::new("ghost")).is_err());
    }

    // ── DFS (result excludes the start node) ────────────────────────────

    #[test]
    fn test_dfs_visits_all_reachable_nodes() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        add(&g, "d");
        link(&g, "a", "b");
        link(&g, "a", "c");
        link(&g, "b", "d");
        let visited = g.dfs(&EntityId::new("a")).unwrap();
        assert_eq!(visited.len(), 3);
    }

    #[test]
    fn test_dfs_handles_isolated_node() {
        let g = make_graph();
        add(&g, "a");
        let visited = g.dfs(&EntityId::new("a")).unwrap();
        assert!(visited.is_empty());
    }

    #[test]
    fn test_dfs_missing_start_returns_error() {
        let g = make_graph();
        assert!(g.dfs(&EntityId::new("ghost")).is_err());
    }

    // ── Unweighted shortest path (paths include both endpoints) ─────────

    #[test]
    fn test_shortest_path_direct_connection() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        link(&g, "a", "b");
        let path = g
            .shortest_path(&EntityId::new("a"), &EntityId::new("b"))
            .unwrap();
        assert_eq!(path, Some(vec![EntityId::new("a"), EntityId::new("b")]));
    }

    #[test]
    fn test_shortest_path_multi_hop() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        link(&g, "a", "b");
        link(&g, "b", "c");
        let path = g
            .shortest_path(&EntityId::new("a"), &EntityId::new("c"))
            .unwrap();
        assert_eq!(path.as_ref().map(|p| p.len()), Some(3));
    }

    #[test]
    fn test_shortest_path_returns_none_for_disconnected() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        let path = g
            .shortest_path(&EntityId::new("a"), &EntityId::new("b"))
            .unwrap();
        assert_eq!(path, None);
    }

    #[test]
    fn test_shortest_path_same_node_returns_single_element() {
        let g = make_graph();
        add(&g, "a");
        let path = g
            .shortest_path(&EntityId::new("a"), &EntityId::new("a"))
            .unwrap();
        assert_eq!(path, Some(vec![EntityId::new("a")]));
    }

    #[test]
    fn test_shortest_path_missing_source_returns_error() {
        let g = make_graph();
        add(&g, "b");
        assert!(g
            .shortest_path(&EntityId::new("ghost"), &EntityId::new("b"))
            .is_err());
    }

    #[test]
    fn test_shortest_path_missing_target_returns_error() {
        let g = make_graph();
        add(&g, "a");
        assert!(g
            .shortest_path(&EntityId::new("a"), &EntityId::new("ghost"))
            .is_err());
    }

    // ── Transitive closure (includes the start node) ────────────────────

    #[test]
    fn test_transitive_closure_includes_start() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        link(&g, "a", "b");
        let closure = g.transitive_closure(&EntityId::new("a")).unwrap();
        assert!(closure.contains(&EntityId::new("a")));
        assert!(closure.contains(&EntityId::new("b")));
    }

    #[test]
    fn test_transitive_closure_isolated_node_contains_only_self() {
        let g = make_graph();
        add(&g, "a");
        let closure = g.transitive_closure(&EntityId::new("a")).unwrap();
        assert_eq!(closure.len(), 1);
    }

    // ── Error conversion ────────────────────────────────────────────────

    #[test]
    fn test_mem_graph_error_converts_to_runtime_error() {
        let e = MemGraphError::EntityNotFound("x".into());
        let re: AgentRuntimeError = e.into();
        assert!(matches!(re, AgentRuntimeError::Graph(_)));
    }

    // ── Weighted shortest path ──────────────────────────────────────────

    #[test]
    fn test_shortest_path_weighted_simple() {
        let g = make_graph();
        // a->b->c costs 1 + 2 = 3, cheaper than the direct a->c edge (10).
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        link_w(&g, "a", "b", 1.0);
        link_w(&g, "b", "c", 2.0);
        g.add_relationship(Relationship::new("a", "c", "DIRECT", 10.0))
            .unwrap();

        let result = g
            .shortest_path_weighted(&EntityId::new("a"), &EntityId::new("c"))
            .unwrap();
        assert!(result.is_some());
        let (path, weight) = result.unwrap();
        assert_eq!(
            // The two-hop route wins over the heavier direct edge.
            path,
            vec![EntityId::new("a"), EntityId::new("b"), EntityId::new("c")]
        );
        assert!((weight - 3.0).abs() < 1e-5);
    }

    #[test]
    fn test_shortest_path_weighted_returns_none_for_disconnected() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        let result = g
            .shortest_path_weighted(&EntityId::new("a"), &EntityId::new("b"))
            .unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_shortest_path_weighted_same_node() {
        let g = make_graph();
        add(&g, "a");
        let result = g
            .shortest_path_weighted(&EntityId::new("a"), &EntityId::new("a"))
            .unwrap();
        assert_eq!(result, Some((vec![EntityId::new("a")], 0.0)));
    }

    // Negative weights are rejected by the weighted search, not by insertion.
    #[test]
    fn test_shortest_path_weighted_negative_weight_errors() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        g.add_relationship(Relationship::new("a", "b", "NEG", -1.0))
            .unwrap();
        let result = g.shortest_path_weighted(&EntityId::new("a"), &EntityId::new("b"));
        assert!(result.is_err());
    }

    // ── Centrality ──────────────────────────────────────────────────────

    #[test]
    fn test_degree_centrality_basic() {
        let g = make_graph();
        // Star: a -> {b, c, d}, n = 4 so the normaliser is n - 1 = 3.
        // a: degree 3 -> 1.0; b: degree 1 -> 1/3.
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        add(&g, "d");
        link(&g, "a", "b");
        link(&g, "a", "c");
        link(&g, "a", "d");

        let centrality = g.degree_centrality().unwrap();
        let a_cent = *centrality.get(&EntityId::new("a")).unwrap();
        let b_cent = *centrality.get(&EntityId::new("b")).unwrap();

        assert!((a_cent - 1.0).abs() < 1e-5, "a centrality was {a_cent}");
        assert!(
            (b_cent - 1.0 / 3.0).abs() < 1e-5,
            "b centrality was {b_cent}"
        );
    }

    #[test]
    fn test_betweenness_centrality_chain() {
        let g = make_graph();
        // Directed chain a -> b -> c -> d: only the interior nodes lie on
        // shortest paths between other pairs.
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        add(&g, "d");
        link(&g, "a", "b");
        link(&g, "b", "c");
        link(&g, "c", "d");

        let centrality = g.betweenness_centrality().unwrap();
        let a_cent = *centrality.get(&EntityId::new("a")).unwrap();
        let b_cent = *centrality.get(&EntityId::new("b")).unwrap();
        let c_cent = *centrality.get(&EntityId::new("c")).unwrap();
        let d_cent = *centrality.get(&EntityId::new("d")).unwrap();

        assert!((a_cent).abs() < 1e-5, "expected a_cent ~ 0, got {a_cent}");
        assert!(b_cent > 0.0, "expected b_cent > 0, got {b_cent}");
        assert!(c_cent > 0.0, "expected c_cent > 0, got {c_cent}");
        assert!((d_cent).abs() < 1e-5, "expected d_cent ~ 0, got {d_cent}");
    }

    // ── Community detection ─────────────────────────────────────────────

    #[test]
    fn test_label_propagation_communities_two_clusters() {
        let g = make_graph();
        // Two disjoint triangles, {a, b, c} and {x, y, z}, with every edge
        // added in both directions.
        for id in &["a", "b", "c", "x", "y", "z"] {
            add(&g, id);
        }
        link(&g, "a", "b");
        link(&g, "b", "a");
        link(&g, "b", "c");
        link(&g, "c", "b");
        link(&g, "a", "c");
        link(&g, "c", "a");
        link(&g, "x", "y");
        link(&g, "y", "x");
        link(&g, "y", "z");
        link(&g, "z", "y");
        link(&g, "x", "z");
        link(&g, "z", "x");

        let communities = g.label_propagation_communities(100).unwrap();

        let label_a = communities[&EntityId::new("a")];
        let label_b = communities[&EntityId::new("b")];
        let label_c = communities[&EntityId::new("c")];
        let label_x = communities[&EntityId::new("x")];
        let label_y = communities[&EntityId::new("y")];
        let label_z = communities[&EntityId::new("z")];

        // Only label equality is meaningful; the numeric ids are arbitrary.
        assert_eq!(label_a, label_b, "a and b should be in same community");
        assert_eq!(label_b, label_c, "b and c should be in same community");
        assert_eq!(label_x, label_y, "x and y should be in same community");
        assert_eq!(label_y, label_z, "y and z should be in same community");
        assert_ne!(
            label_a, label_x,
            "cluster 1 and cluster 2 should be different communities"
        );
    }

    // ── Subgraph ────────────────────────────────────────────────────────

    #[test]
    fn test_subgraph_extracts_correct_nodes_and_edges() {
        let g = make_graph();
        // Chain a -> b -> c -> d; extracting {a, b, c} keeps a->b and b->c
        // and drops c->d because d is outside the selection.
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        add(&g, "d");
        link(&g, "a", "b");
        link(&g, "b", "c");
        link(&g, "c", "d");

        let sub = g
            .subgraph(&[EntityId::new("a"), EntityId::new("b"), EntityId::new("c")])
            .unwrap();

        assert_eq!(sub.entity_count().unwrap(), 3);
        assert_eq!(sub.relationship_count().unwrap(), 2);

        assert!(sub.get_entity(&EntityId::new("d")).is_err());

        let path = sub
            // The copied edges must still be traversable in the new store.
            .shortest_path(&EntityId::new("a"), &EntityId::new("c"))
            .unwrap();
        assert!(path.is_some());
        assert_eq!(path.unwrap().len(), 3);
    }
}